Optimize attention in llama4 (#5127)
This commit is contained in:
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
|
||||
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
||||
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
||||
|
||||
return attn_scale.unsqueeze(-1)
|
||||
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def _mul_attn_scale(self, positions, q):
|
||||
attn_scale = self._get_attn_scale(positions)
|
||||
return (q * attn_scale).to(q.dtype)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
|
||||
forward_batch: ForwardBatch,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||
|
||||
if self.rotary_emb is not None:
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
|
||||
q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
|
||||
assert (q_out_unused is q_view) and (k_out_unused is k_view)
|
||||
del q_view, k_view, q_out_unused, k_out_unused
|
||||
|
||||
if self.qk_norm is not None:
|
||||
# TODO: support float
|
||||
q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||
k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||
q = self.qk_norm(q).to(q.dtype)
|
||||
k = self.qk_norm(k).to(k.dtype)
|
||||
q = q.reshape(-1, self.q_size)
|
||||
k = k.reshape(-1, self.kv_size)
|
||||
# TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
|
||||
qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||
qk = self.qk_norm(qk).to(torch.bfloat16)
|
||||
qk = qk.reshape(-1, self.q_size + self.kv_size)
|
||||
|
||||
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
|
||||
|
||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||
# the inference-time temperature tuning function is customized to not affect short context
|
||||
# while working at very long context
|
||||
# https://arxiv.org/abs/2501.19399
|
||||
if self.attn_temperature_tuning and not self.use_rope:
|
||||
attn_scale = self._get_attn_scale(positions)
|
||||
q = (q * attn_scale).to(q.dtype)
|
||||
q = self._mul_attn_scale(positions=positions, q=q)
|
||||
|
||||
attn_output = self.attn(q, k, v, forward_batch)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
|
||||
Reference in New Issue
Block a user