Transpose mla weight offline (#1261)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -417,12 +417,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
v_head_dim=self.kv_lora_rank,
|
v_head_dim=self.kv_lora_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
kv_b_proj = self.kv_b_proj
|
self.w_kc = None
|
||||||
w_kc, w_vc = kv_b_proj.weight.unflatten(
|
self.w_vc = None
|
||||||
0, (-1, qk_nope_head_dim + v_head_dim)
|
|
||||||
).split([qk_nope_head_dim, v_head_dim], dim=1)
|
|
||||||
self.w_kc = w_kc
|
|
||||||
self.w_vc = w_vc
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -464,7 +460,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
)
|
)
|
||||||
torch.bmm(
|
torch.bmm(
|
||||||
attn_output.transpose(0, 1),
|
attn_output.transpose(0, 1),
|
||||||
self.w_vc.transpose(1, 2).contiguous(),
|
self.w_vc,
|
||||||
out=attn_bmm_output.transpose(0, 1),
|
out=attn_bmm_output.transpose(0, 1),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -715,5 +711,15 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
)
|
)
|
||||||
weight_loader(param, loaded_weight)
|
weight_loader(param, loaded_weight)
|
||||||
|
|
||||||
|
if global_server_args_dict["enable_mla"]:
|
||||||
|
for layer_id in range(self.config.num_hidden_layers):
|
||||||
|
self_attn = self.model.layers[layer_id].self_attn
|
||||||
|
w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
|
||||||
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
||||||
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
||||||
|
self_attn.w_kc = w_kc.contiguous()
|
||||||
|
self_attn.w_vc = w_vc.transpose(1, 2).contiguous()
|
||||||
|
del self_attn.kv_b_proj
|
||||||
|
|
||||||
|
|
||||||
EntryClass = DeepseekV2ForCausalLM
|
EntryClass = DeepseekV2ForCausalLM
|
||||||
|
|||||||
Reference in New Issue
Block a user