diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 3d432c1cc..8017cefa4 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -677,44 +677,94 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()
 
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
-        if hidden_states.shape[0] == 0:
-            assert (
-                not self.o_proj.reduce_results
-            ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
 
-        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
-        if attn_forward_method == AttnForwardMethod.MHA:
-            return self.forward_normal(positions, hidden_states, forward_batch)
-        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-            return self.forward_normal_chunked_kv(
-                positions, hidden_states, forward_batch
-            )
-        elif attn_forward_method == AttnForwardMethod.MLA:
-            return self.forward_absorb(
-                positions, hidden_states, forward_batch, zero_allocator
-            )
-        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-            return self.forward_absorb_fused_mla_rope(
-                positions, hidden_states, forward_batch
-            )
-        else:
-            raise NotImplementedError
-
-    def forward_normal(
+    def forward_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
+        if hidden_states.shape[0] == 0:
+            assert (
+                not self.o_proj.reduce_results
+            ), "short-circuiting allreduce will lead to hangs"
+            return hidden_states, None, forward_batch, None
+
+        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            inner_state = self.forward_absorb_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+        else:
+            raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError
+
+    def forward_normal_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -749,18 +799,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
+
+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
 
-    def forward_absorb(
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
@@ -829,6 +883,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out.transpose(0, 1)
 
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
         if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -881,13 +940,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
-    def forward_absorb_fused_mla_rope(
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    ) -> torch.Tensor:
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
@@ -976,6 +1035,44 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
 
+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1082,12 +1179,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return accum_output
 
-    def forward_normal_chunked_kv(
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-    ) -> torch.Tensor:
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1228,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
 
+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1283,14 +1384,6 @@ class DeepseekV2DecoderLayer(nn.Module):
             )
         )
 
-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
diff --git a/python/sglang/srt/operations_strategy.py b/python/sglang/srt/operations_strategy.py
index c30a53ac9..be0577ce2 100644
--- a/python/sglang/srt/operations_strategy.py
+++ b/python/sglang/srt/operations_strategy.py
@@ -7,7 +7,8 @@ def compute_layer_operations(
     if not layer.is_layer_sparse:
         return [
             layer.op_comm_prepare_attn,
-            layer.op_attn,
+            layer.self_attn.op_prepare,
+            layer.self_attn.op_core,
             layer.op_comm_prepare_mlp,
             layer.op_mlp,
             layer.op_comm_postprocess_layer,
         ]
@@ -16,7 +17,8 @@ def compute_layer_operations(
     # Will add TBO operation orders here
     return [
         layer.op_comm_prepare_attn,
-        layer.op_attn,
+        layer.self_attn.op_prepare,
+        layer.self_attn.op_core,
         layer.op_comm_prepare_mlp,
         layer.mlp.op_gate,
        layer.mlp.op_shared_experts,
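
Note on the motivation: splitting the monolithic `op_attn` into `self_attn.op_prepare` and `self_attn.op_core` lets an operations strategy interleave finer-grained stages from two micro-batches (the `# Will add TBO operation orders here` placeholder in operations_strategy.py). The sketch below illustrates that scheduling idea only; `OpState`, `make_layer_ops`, and `run_two_batch_overlap` are hypothetical names invented for this note, not sglang APIs. Only the prepare/core split itself comes from the diff.

```python
# Minimal, hypothetical sketch of the scheduling idea behind the
# op_prepare/op_core split. None of these names exist in sglang.
from typing import Any, Callable, Dict, List


class OpState:
    """Per-micro-batch state bag, like the `state` object the ops mutate."""

    def __init__(self) -> None:
        self.values: Dict[str, Any] = {}

    def pop(self, key: str) -> Any:
        return self.values.pop(key)


def make_layer_ops(tag: str) -> List[Callable[[OpState], None]]:
    # Stand-ins for layer.op_comm_prepare_attn, self_attn.op_prepare,
    # and self_attn.op_core from the operation list above.
    def comm_prepare_attn(s: OpState) -> None:
        s.values["hidden_states_after_comm_pre_attn"] = f"{tag}:hidden"

    def attn_prepare(s: OpState) -> None:
        s.values["attn_intermediate_state"] = s.pop(
            "hidden_states_after_comm_pre_attn"
        )

    def attn_core(s: OpState) -> None:
        s.values["hidden_states_after_attn"] = s.pop("attn_intermediate_state")

    return [comm_prepare_attn, attn_prepare, attn_core]


def run_two_batch_overlap(
    ops_a: List[Callable[[OpState], None]],
    ops_b: List[Callable[[OpState], None]],
    state_a: OpState,
    state_b: OpState,
) -> None:
    # Alternate ops from the two micro-batches. With the finer-grained ops,
    # batch B's prepare/communication work can overlap batch A's attention
    # core on real hardware; here we just demonstrate the interleaved order.
    for op_a, op_b in zip(ops_a, ops_b):
        op_a(state_a)
        op_b(state_b)


a, b = OpState(), OpState()
run_two_batch_overlap(make_layer_ops("a"), make_layer_ops("b"), a, b)
print(a.values)  # {'hidden_states_after_attn': 'a:hidden'}
```

This also explains the shape of `forward_prepare`'s return value: the intermediate tuple carries everything `forward_core` needs, so the two halves can run at different points in the schedule without re-deriving any state.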