Optimize attention in llama4 (#5127)
This commit is contained in:
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
|
|||||||
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
|
||||||
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
floor = torch.floor((positions + 1.0) / self.floor_scale)
|
||||||
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
|
||||||
|
|
||||||
return attn_scale.unsqueeze(-1)
|
return attn_scale.unsqueeze(-1)
|
||||||
|
|
||||||
|
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||||
|
def _mul_attn_scale(self, positions, q):
|
||||||
|
attn_scale = self._get_attn_scale(positions)
|
||||||
|
return (q * attn_scale).to(q.dtype)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
|
|||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(hidden_states)
|
qkv, _ = self.qkv_proj(hidden_states)
|
||||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
|
||||||
|
qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
|
||||||
|
|
||||||
if self.rotary_emb is not None:
|
if self.rotary_emb is not None:
|
||||||
q, k = self.rotary_emb(positions, q, k)
|
q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
|
||||||
|
q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
|
||||||
|
assert (q_out_unused is q_view) and (k_out_unused is k_view)
|
||||||
|
del q_view, k_view, q_out_unused, k_out_unused
|
||||||
|
|
||||||
if self.qk_norm is not None:
|
if self.qk_norm is not None:
|
||||||
# TODO: support float
|
# TODO there are still 2 redundant direct_copy_kernel_cuda for this `reshape` and (in attn backend) q.contiguous(), maybe we can fuse them later
|
||||||
q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
|
qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
|
||||||
k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
|
qk = self.qk_norm(qk).to(torch.bfloat16)
|
||||||
q = self.qk_norm(q).to(q.dtype)
|
qk = qk.reshape(-1, self.q_size + self.kv_size)
|
||||||
k = self.qk_norm(k).to(k.dtype)
|
|
||||||
q = q.reshape(-1, self.q_size)
|
q, k = qk.split([self.q_size, self.kv_size], dim=-1)
|
||||||
k = k.reshape(-1, self.kv_size)
|
|
||||||
|
|
||||||
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
# We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
|
||||||
# the inference-time temperature tuning function is customized to not affect short context
|
# the inference-time temperature tuning function is customized to not affect short context
|
||||||
# while working at very long context
|
# while working at very long context
|
||||||
# https://arxiv.org/abs/2501.19399
|
# https://arxiv.org/abs/2501.19399
|
||||||
if self.attn_temperature_tuning and not self.use_rope:
|
if self.attn_temperature_tuning and not self.use_rope:
|
||||||
attn_scale = self._get_attn_scale(positions)
|
q = self._mul_attn_scale(positions=positions, q=q)
|
||||||
q = (q * attn_scale).to(q.dtype)
|
|
||||||
|
|
||||||
attn_output = self.attn(q, k, v, forward_batch)
|
attn_output = self.attn(q, k, v, forward_batch)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
|
|||||||
Reference in New Issue
Block a user