From 3f87f83116ccda4738b98db7a6dee4cfec3c78e4 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Tue, 22 Apr 2025 20:35:08 -0700
Subject: [PATCH] Fuse q_a_proj and kv_a_proj (#5619)

---
 python/sglang/srt/models/deepseek_v2.py | 103 ++++++++++++++++++------
 1 file changed, 78 insertions(+), 25 deletions(-)

diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 8fab2c488..3d230c326 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -443,12 +443,12 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         # For tensor parallel attention
         if self.q_lora_rank is not None:
-            self.q_a_proj = ReplicatedLinear(
+            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
                 self.hidden_size,
-                self.q_lora_rank,
+                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
-                prefix=add_prefix("q_a_proj", prefix),
+                prefix=add_prefix("fused_qkv_a_proj_with_mqa", prefix),
             )
             self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
             self.q_b_proj = ColumnParallelLinear(
@@ -470,6 +470,14 @@ class DeepseekV2AttentionMLA(nn.Module):
                 tp_rank=attn_tp_rank,
                 tp_size=attn_tp_size,
             )
+            self.kv_a_proj_with_mqa = ReplicatedLinear(
+                self.hidden_size,
+                self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                quant_config=quant_config,
+                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -490,14 +498,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
         )
-
-        self.kv_a_proj_with_mqa = ReplicatedLinear(
-            self.hidden_size,
-            self.kv_lora_rank + self.qk_rope_head_dim,
-            bias=False,
-            quant_config=quant_config,
-            prefix=add_prefix("kv_a_proj_with_mqa", prefix),
-        )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
         if rope_scaling:
@@ -656,15 +656,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -699,13 +702,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
         if self.use_deep_gemm_bmm:
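Taken together, the hunks above make one change: when q_lora_rank is set, the two down-projections that both read hidden_states (q_a_proj and kv_a_proj_with_mqa) are merged into the single fused_qkv_a_proj_with_mqa GEMM, and q plus the latent_cache are recovered with one split; only the q_lora_rank is None path keeps a standalone kv_a_proj_with_mqa, which is why each forward path now computes latent_cache in its else branch. Below is a minimal standalone sketch of why the fusion is numerically equivalent, using toy shapes and plain torch in place of SGLang's ReplicatedLinear:

```python
import torch
import torch.nn.functional as F

# Illustrative sizes only; the real config dimensions are much larger.
hidden_size, q_lora_rank = 32, 8
kv_lora_rank, qk_rope_head_dim = 16, 4

x = torch.randn(3, hidden_size)
w_q = torch.randn(q_lora_rank, hidden_size)                       # q_a_proj.weight
w_kv = torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size)  # kv_a_proj_with_mqa.weight

# Before: two GEMMs over the same activations.
q_ref = F.linear(x, w_q)
kv_ref = F.linear(x, w_kv)

# After: one GEMM over the row-concatenated weight, then one split.
w_fused = torch.cat([w_q, w_kv], dim=0)
q, latent_cache = F.linear(x, w_fused).split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1
)
assert torch.allclose(q, q_ref) and torch.allclose(latent_cache, kv_ref)
```

Both original layers are ReplicatedLinear, so every tensor-parallel rank holds the full weights and the row concatenation needs no resharding; a TP-sharded pair could not be fused this directly.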
@@ -744,7 +750,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out.transpose(0, 1)
 
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         k_nope = latent_cache[..., : self.kv_lora_rank]
         k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
         k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
 
@@ -819,13 +824,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_len, self.num_local_heads, self.kv_lora_rank + self.qk_rope_head_dim
         )
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
         if self.w_kc.dtype == torch.float8_e4m3fnuz:
@@ -846,8 +854,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
         q_input[..., : self.kv_lora_rank] = q_nope_out.transpose(0, 1)
-
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         v_input = latent_cache[..., : self.kv_lora_rank]
         v_input = self.kv_a_layernorm(v_input.contiguous()).unsqueeze(1)
         k_input = latent_cache.unsqueeze(1)
@@ -1018,15 +1024,17 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         # First do normal mha forward to get output for extended part
         if self.q_lora_rank is not None:
-            q = self.q_a_proj(hidden_states)[0]
+            q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
             q = self.q_a_layernorm(q)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
             q = self.q_proj(hidden_states)[0].view(
                 -1, self.num_local_heads, self.qk_head_dim
             )
+            latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a.contiguous())
@@ -1668,6 +1676,12 @@ class DeepseekV2ForCausalLM(nn.Module):
             num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
         )
 
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             # TODO(HandH1998): Modify it when nextn is supported.
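The last hunk above sets up the fuse_qkv_a_proj flag and the cached_a_proj buffer that the final hunk below consumes. Checkpoints still store q_a_proj and kv_a_proj_with_mqa as separate tensors, and the weight iterator yields them in no guaranteed order, so each half is buffered until its partner arrives and the pair is concatenated along the output dimension. A compact sketch of that pattern, as an illustrative helper rather than code from the patch:

```python
import torch

def fuse_a_proj_weights(weights):
    """Buffer q_a_proj / kv_a_proj_with_mqa halves until both have been seen,
    then emit one fused tensor; all other weights pass straight through."""
    cached = {}
    for name, w in weights:
        if "q_a_proj" not in name and "kv_a_proj_with_mqa" not in name:
            yield name, w
            continue
        cached[name] = w
        # Derive both partner names from whichever half just arrived.
        q_name = name.replace("kv_a_proj_with_mqa", "q_a_proj")
        kv_name = name.replace("q_a_proj", "kv_a_proj_with_mqa")
        if q_name in cached and kv_name in cached:
            # q rows first, matching the forward-time split order.
            fused = torch.cat([cached.pop(q_name), cached.pop(kv_name)], dim=0)
            yield q_name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa"), fused
```

The order of concatenation matters: the q_a_proj rows go first so that the fused weight matches the [q_lora_rank, kv_lora_rank + qk_rope_head_dim] split performed in the forward passes above.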
@@ -1723,11 +1737,50 @@ class DeepseekV2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = getattr(
-                    param, "weight_loader", default_weight_loader
-                )
-                weight_loader(param, loaded_weight)
+                if fuse_qkv_a_proj and (
+                    "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                ):
+                    cached_a_proj[name] = loaded_weight
+                    q_a_proj_name = (
+                        name
+                        if "q_a_proj" in name
+                        else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                    )
+                    kv_a_proj_name = (
+                        name
+                        if "kv_a_proj_with_mqa" in name
+                        else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                    )
+
+                    # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
+                    if (
+                        q_a_proj_name in cached_a_proj
+                        and kv_a_proj_name in cached_a_proj
+                    ):
+
+                        q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                        kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                        fused_weight = torch.cat(
+                            [q_a_proj_weight, kv_a_proj_weight], dim=0
+                        )
+
+                        param_name = q_a_proj_name.replace(
+                            "q_a_proj", "fused_qkv_a_proj_with_mqa"
+                        )
+                        param = params_dict[param_name]
+
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        weight_loader(param, fused_weight)
+                        cached_a_proj.pop(q_a_proj_name)
+                        cached_a_proj.pop(kv_a_proj_name)
+                else:
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
 
         self.post_load_weights()
 
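Note that the fused parameter name is derived from q_a_proj_name rather than from the raw name, since the completing weight may be the kv_a_proj_with_mqa half, whose name contains no "q_a_proj" substring to replace. A quick usage sketch for the helper above, checking that arrival order does not matter; the layer names and shapes here are hypothetical stand-ins:

```python
import torch

# Toy stand-ins for the two checkpoint tensors (rows = output features).
w_q = torch.randn(8, 32)    # q_a_proj.weight
w_kv = torch.randn(20, 32)  # kv_a_proj_with_mqa.weight

shard = [
    ("model.layers.0.self_attn.kv_a_proj_with_mqa.weight", w_kv),  # kv half first
    ("model.layers.0.self_attn.q_a_proj.weight", w_q),
]
[(fused_name, fused)] = list(fuse_a_proj_weights(shard))
assert fused_name == "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa.weight"
assert torch.equal(fused, torch.cat([w_q, w_kv], dim=0))  # q rows first
```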