Source code for lightning_ir.flash.flash_bert
from typing import Tuple

import torch
from transformers.models.bert.modeling_bert import BertSelfAttention

try:
    from flash_attn import flash_attn_func
except ImportError:
    # flash-attn is an optional dependency; when it is not installed, the
    # fallback to PyTorch's scaled_dot_product_attention below is used.
    flash_attn_func = None


def flash_attention_forward(
    self: BertSelfAttention,
    hidden_states: torch.Tensor,
    attention_mask: torch.FloatTensor | None,
    *args,
    **kwargs,
) -> Tuple[torch.Tensor]:
    # Project the hidden states and split them into per-head tensors of shape
    # (batch, num_heads, seq_len, head_dim).
    query = self.transpose_for_scores(self.query(hidden_states))
    key = self.transpose_for_scores(self.key(hidden_states))
    value = self.transpose_for_scores(self.value(hidden_states))

    # An all-zero additive mask masks nothing, so drop it; this also enables
    # the FlashAttention path, which does not accept an arbitrary mask.
    if attention_mask is not None and not attention_mask.any():
        attention_mask = None

    if flash_attn_func is not None and hidden_states.is_cuda and attention_mask is None:
        # FlashAttention requires CUDA and half precision and expects inputs of
        # shape (batch, seq_len, num_heads, head_dim); cast to bfloat16, run
        # the kernel, then restore the original layout and dtype.
        context = (
            flash_attn_func(
                query.bfloat16().transpose(1, 2),
                key.bfloat16().transpose(1, 2),
                value.bfloat16().transpose(1, 2),
                self.dropout.p if self.training else 0,
            )
            .transpose(1, 2)
            .to(query.dtype)
        )
    else:
        # Fallback: PyTorch's fused scaled dot-product attention, which
        # accepts an additive floating-point mask. Dropout is applied only
        # during training.
        context = torch.nn.functional.scaled_dot_product_attention(
            query,
            key,
            value,
            attention_mask.to(query.dtype) if attention_mask is not None else None,
            self.dropout.p if self.training else 0,
        )

    # Merge the heads back together:
    # (batch, num_heads, seq_len, head_dim) -> (batch, seq_len, all_head_size).
    context = context.permute(0, 2, 1, 3).contiguous()
    new_context_shape = context.size()[:-2] + (self.all_head_size,)
    context = context.view(new_context_shape)
    return (context,)


# In Hugging Face BERT, each attention block stores its self-attention module
# under the attribute name "self" (e.g. "encoder.layer.0.attention.self").
SELF_ATTENTION_PATTERN = "self"
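

Usage sketch (illustrative, not part of the module source): the forward above is evidently meant to be bound onto a BERT model's self-attention submodules, with SELF_ATTENTION_PATTERN naming the attribute to match. A minimal sketch of that wiring, assuming name matching against the pattern and binding via types.MethodType; the library's own patching code may differ.

import types

from transformers import BertModel

model = BertModel.from_pretrained("bert-base-uncased")
for name, module in model.named_modules():
    # Match submodules whose attribute name equals the pattern, e.g.
    # "encoder.layer.0.attention.self" (an assumption for this sketch).
    if name.split(".")[-1] == SELF_ATTENTION_PATTERN and isinstance(module, BertSelfAttention):
        # Bind flash_attention_forward as this module's forward method.
        module.forward = types.MethodType(flash_attention_forward, module)

After patching, a normal forward pass routes through flash_attention_forward, using the FlashAttention kernel on CUDA when no effective attention mask is present and falling back to torch.nn.functional.scaled_dot_product_attention otherwise.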