From d8ab60117fd6101bd52697e4e1214153c0c1d9cd Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Sat, 3 May 2025 00:26:30 +0800
Subject: [PATCH] Overlap qk norm with two streams (#5977)

---
 python/sglang/srt/models/deepseek_v2.py | 32 ++++++++++++++++++++-----
 1 file changed, 26 insertions(+), 6 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 85695b114..35c19e14b 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -421,6 +421,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         reduce_results: bool = True,
         layer_id: int = None,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -543,6 +544,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("attn_mha", prefix),
             )
 
+        self.alt_stream = alt_stream
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -706,14 +709,32 @@ class DeepseekV2AttentionMLA(nn.Module):
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
             )
-            q = self.q_a_layernorm(q)
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)
+
+            k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+            k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
+
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
         if self.use_deep_gemm_bmm:
             q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
@@ -750,11 +771,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
 
         q_nope_out = q_nope_out.transpose(0, 1)
-
-        k_nope = latent_cache[..., : self.kv_lora_rank]
-        k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
-        k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
-
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
         if self.attention_backend == "fa3":
@@ -1104,6 +1120,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -1133,6 +1150,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 layer_id=layer_id,
                 reduce_results=False,
                 prefix=add_prefix("self_attn", prefix),
+                alt_stream=alt_stream,
             )
 
         self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
@@ -1376,6 +1394,7 @@ class DeepseekV2Model(nn.Module):
             config.hidden_size,
             enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 DeepseekV2DecoderLayer(
                     config,
                     layer_id,
                     quant_config=quant_config,
                     prefix=add_prefix(f"layers.{layer_id}", prefix),
+                    alt_stream=self.alt_stream,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
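---

For context, a minimal standalone sketch of the two-stream overlap pattern this diff introduces, assuming a recent PyTorch build that provides nn.RMSNorm; the layer sizes, tensor shapes, and the helper name below are illustrative placeholders, not the actual SGLang modules:

# Minimal sketch of the q/k norm overlap pattern from the diff above.
# The RMSNorm sizes (1536, 512) and overlapped_qk_norm are placeholder
# assumptions for illustration, not SGLang code.
import torch
import torch.nn as nn

q_a_layernorm = nn.RMSNorm(1536, device="cuda")
kv_a_layernorm = nn.RMSNorm(512, device="cuda")
alt_stream = torch.cuda.Stream()


def overlapped_qk_norm(q: torch.Tensor, k_nope: torch.Tensor):
    current_stream = torch.cuda.current_stream()
    # The alternate stream must not start before q and k_nope are produced.
    alt_stream.wait_stream(current_stream)
    q = q_a_layernorm(q)  # runs on the current stream
    with torch.cuda.stream(alt_stream):
        k_nope = kv_a_layernorm(k_nope)  # runs concurrently on alt_stream
    # Later kernels on the current stream must wait for the alternate stream.
    current_stream.wait_stream(alt_stream)
    return q, k_nope

Note that the patch gates this path on torch.cuda.is_current_stream_capturing(), so the two norms are only overlapped while a CUDA graph is being captured and fall back to sequential execution in eager mode.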