Source code for lightning_ir.flash.flash_bert

from typing import Tuple

import torch
from transformers.models.bert.modeling_bert import BertSelfAttention

try:
    from flash_attn import flash_attn_func
except ImportError:
    flash_attn_func = None


def flash_attention_forward(
    self: BertSelfAttention,
    hidden_states: torch.Tensor,
    attention_mask: torch.FloatTensor | None,
    *args,
    **kwargs,
) -> Tuple[torch.Tensor]:
    # Project the hidden states and reshape to (batch, heads, seq_len, head_dim).
    query = self.transpose_for_scores(self.query(hidden_states))
    key = self.transpose_for_scores(self.key(hidden_states))
    value = self.transpose_for_scores(self.value(hidden_states))

    # An all-zero additive mask masks nothing, so drop it entirely.
    if attention_mask is not None and not attention_mask.any():
        attention_mask = None

    if flash_attn_func is not None and hidden_states.is_cuda and attention_mask is None:
        # flash_attn_func expects (batch, seq_len, heads, head_dim) tensors in half precision.
        context = (
            flash_attn_func(
                query.bfloat16().transpose(1, 2),
                key.bfloat16().transpose(1, 2),
                value.bfloat16().transpose(1, 2),
                self.dropout.p if self.training else 0,
            )
            .transpose(1, 2)
            .to(query.dtype)
        )
    else:
        # Fall back to PyTorch's scaled_dot_product_attention when flash-attn is
        # unavailable, the inputs are on CPU, or an attention mask must be applied.
        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

    # Merge the attention heads back into a single hidden dimension.
    context = context.permute(0, 2, 1, 3).contiguous()
    new_context_shape = context.size()[:-2] + (self.all_head_size,)
    context = context.view(new_context_shape)
    return (context,)


SELF_ATTENTION_PATTERN = "self"  # attribute name of the self-attention module inside BertAttention
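
The sketch below is not part of the module above; it shows one way to exercise flash_attention_forward by assigning it to BertSelfAttention.forward before building a Hugging Face BERT model, so every encoder layer routes its attention through the patched function. The direct assignment and the bert-base-uncased checkpoint are illustrative assumptions; the library itself presumably locates the attention submodules via SELF_ATTENTION_PATTERN rather than patching the class globally. On recent transformers versions the model may need attn_implementation="eager" so the base BertSelfAttention class (and not an SDPA subclass) is instantiated.

# Usage sketch (illustrative assumption, not part of the library source).
import torch
from transformers import AutoTokenizer, BertModel
from transformers.models.bert.modeling_bert import BertSelfAttention

from lightning_ir.flash.flash_bert import flash_attention_forward

# Route every BERT self-attention layer through the patched forward.
BertSelfAttention.forward = flash_attention_forward

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# attn_implementation="eager" keeps the base BertSelfAttention class on recent
# transformers releases; drop the kwarg on older versions that lack it.
model = BertModel.from_pretrained("bert-base-uncased", attn_implementation="eager").eval()

inputs = tokenizer("flash attention usage sketch", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # (1, sequence_length, 768)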