From cd7e32e2cb150fbf216c5c05697139c68bab4a8d Mon Sep 17 00:00:00 2001
From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
Date: Fri, 11 Apr 2025 15:32:41 +0800
Subject: [PATCH] Optimize attention in llama4 (#5127)

---
 python/sglang/srt/models/llama4.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py
index d933f27ae..88c3716f7 100644
--- a/python/sglang/srt/models/llama4.py
+++ b/python/sglang/srt/models/llama4.py
@@ -240,9 +240,13 @@ class Llama4Attention(nn.Module):
     def _get_attn_scale(self, positions: torch.Tensor) -> torch.Tensor:
         floor = torch.floor((positions + 1.0) / self.floor_scale)
         attn_scale = torch.log(floor + 1.0) * self.attn_scale + 1.0
         return attn_scale.unsqueeze(-1)
 
+    @torch.compile(dynamic=True, backend=get_compiler_backend())
+    def _mul_attn_scale(self, positions, q):
+        attn_scale = self._get_attn_scale(positions)
+        return (q * attn_scale).to(q.dtype)
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -250,27 +254,29 @@ class Llama4Attention(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+        qk, v = qkv.split([self.q_size + self.kv_size, self.kv_size], dim=-1)
 
         if self.rotary_emb is not None:
-            q, k = self.rotary_emb(positions, q, k)
+            q_view, k_view = qk.split([self.q_size, self.kv_size], dim=-1)
+            q_out_unused, k_out_unused = self.rotary_emb(positions, q_view, k_view)
+            assert (q_out_unused is q_view) and (k_out_unused is k_view)
+            del q_view, k_view, q_out_unused, k_out_unused
 
         if self.qk_norm is not None:
-            # TODO: support float
-            q = q.reshape(-1, self.head_dim).contiguous().bfloat16()
-            k = k.reshape(-1, self.head_dim).contiguous().bfloat16()
-            q = self.qk_norm(q).to(q.dtype)
-            k = self.qk_norm(k).to(k.dtype)
-            q = q.reshape(-1, self.q_size)
-            k = k.reshape(-1, self.kv_size)
+            # TODO: there are still 2 redundant direct_copy_kernel_cuda calls for this `reshape` and (in the attn backend) q.contiguous(); maybe we can fuse them later
+            qk = qk.reshape(-1, self.head_dim).contiguous().bfloat16()
+            qk = self.qk_norm(qk).to(torch.bfloat16)
+            qk = qk.reshape(-1, self.q_size + self.kv_size)
+
+        q, k = qk.split([self.q_size, self.kv_size], dim=-1)
 
         # We are applying temperature tuning (https://arxiv.org/abs/2501.19399) to NoPE layers, where
         # the inference-time temperature tuning function is customized to not affect short context
         # while working at very long context
         # https://arxiv.org/abs/2501.19399
         if self.attn_temperature_tuning and not self.use_rope:
-            attn_scale = self._get_attn_scale(positions)
-            q = (q * attn_scale).to(q.dtype)
+            q = self._mul_attn_scale(positions=positions, q=q)
 
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
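
For reference, below is a minimal standalone sketch of the two ideas this patch applies: keeping q and k fused in one `qk` buffer so a single norm pass and a single reshape cover both before splitting them into views, and the NoPE temperature scaling that the patch wraps in a torch.compile'd helper (`get_compiler_backend()` is sglang's own utility). The shapes, the hand-rolled rms_norm, and the floor_scale / attn_scale values here are illustrative assumptions, not the values or modules used by llama4.py.

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Minimal RMSNorm without a learned weight; a stand-in for self.qk_norm,
    # enough to show the data flow.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

num_tokens, head_dim = 4, 8
q_size, kv_size = 2 * head_dim, head_dim   # e.g. 2 query heads, 1 KV head
qkv = torch.randn(num_tokens, q_size + 2 * kv_size)

# One split for the fused qk block plus v, instead of q/k/v separately.
qk, v = qkv.split([q_size + kv_size, kv_size], dim=-1)

# A single reshape/contiguous/norm round trip covers the q and k heads
# together; before the patch this was done twice, once for q and once for k.
qk = qk.reshape(-1, head_dim).contiguous().bfloat16()
qk = rms_norm(qk).to(torch.bfloat16)
qk = qk.reshape(-1, q_size + kv_size)

# Splitting afterwards yields views into qk, not copies; this is also why the
# patch asserts that rotary_emb returns the same tensors it was given (an
# in-place update keeps the fused buffer valid).
q, k = qk.split([q_size, kv_size], dim=-1)

# The NoPE temperature scaling that _mul_attn_scale compiles into one kernel;
# floor_scale and attn_scale are made-up demo values.
floor_scale, attn_scale = 8192.0, 0.1
positions = torch.arange(num_tokens)
scale = torch.log(torch.floor((positions + 1.0) / floor_scale) + 1.0) * attn_scale + 1.0
q = (q * scale.unsqueeze(-1)).to(q.dtype)

print(q.shape, k.shape, v.shape)  # torch.Size([4, 16]) torch.Size([4, 8]) torch.Size([4, 8])

Fusing the split this way trades three narrow tensors for one wide buffer, so the qk_norm path launches half as many copy/norm kernels, and moving the log/floor/multiply chain behind torch.compile lets those elementwise ops fuse into a single kernel.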