From f414352ae6783dc20dc93e09be00ea62f4438931 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Fri, 30 Aug 2024 14:45:40 +0800 Subject: [PATCH] Transpose mla weight offline (#1261) Co-authored-by: Yineng Zhang --- python/sglang/srt/models/deepseek_v2.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 15ecf4bb6..67d99d512 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -417,12 +417,8 @@ class DeepseekV2AttentionMLA(nn.Module): v_head_dim=self.kv_lora_rank, ) - kv_b_proj = self.kv_b_proj - w_kc, w_vc = kv_b_proj.weight.unflatten( - 0, (-1, qk_nope_head_dim + v_head_dim) - ).split([qk_nope_head_dim, v_head_dim], dim=1) - self.w_kc = w_kc - self.w_vc = w_vc + self.w_kc = None + self.w_vc = None def forward( self, @@ -464,7 +460,7 @@ class DeepseekV2AttentionMLA(nn.Module): ) torch.bmm( attn_output.transpose(0, 1), - self.w_vc.transpose(1, 2).contiguous(), + self.w_vc, out=attn_bmm_output.transpose(0, 1), ) @@ -715,5 +711,15 @@ class DeepseekV2ForCausalLM(nn.Module): ) weight_loader(param, loaded_weight) + if global_server_args_dict["enable_mla"]: + for layer_id in range(self.config.num_hidden_layers): + self_attn = self.model.layers[layer_id].self_attn + w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + self_attn.w_kc = w_kc.contiguous() + self_attn.w_vc = w_vc.transpose(1, 2).contiguous() + del self_attn.kv_b_proj + EntryClass = DeepseekV2ForCausalLM