[BugFix] Fix mlapo accuracy problem related with weight processing. (#3857)
This PR fixes an MLAPO accuracy problem related to weight processing. Furthermore, it modifies the MLAPO-related e2e test to use a quantized DeepSeek model so that the test actually exercises the fixed path. Signed-off-by: whx-sjtu <2952154980@qq.com>
This commit is contained in:
@@ -111,19 +111,3 @@ def test_mtp2_correctness_full_graph(
|
||||
model_name: str,
|
||||
):
|
||||
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_piecewise_graph_with_mlapo_kernel(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Run the MTP correctness check with the MLAPO kernel enabled.

    The environment patch turns on VLLM_ASCEND_ENABLE_MLAPO for the
    duration of the test; the default (piecewise) CUDA-graph mode is
    used since no explicit mode is passed.
    """
    # num_speculative_tokens=1; graph mode left at its default.
    mtp_correctness(sampling_config, model_name, 1)
|
||||
|
||||
|
||||
@patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_MLAPO": "1"})
def test_mtp_correctness_full_graph_with_mlapo_kernel(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Run the MTP correctness check with the MLAPO kernel enabled.

    Same as the piecewise variant above, but forces full CUDA-graph
    capture via CUDAGraphMode.FULL.
    """
    # num_speculative_tokens=1 with full-graph capture.
    mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)
|
||||
|
||||
@@ -676,9 +676,9 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
..., self.q_lora_rank:].contiguous()
|
||||
q_a_proj_wt = self.fused_qkv_a_proj.weight.data[
|
||||
..., :self.q_lora_rank].contiguous()
|
||||
kv_a_proj_wt = kv_a_proj_wt.contiguous()
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
|
||||
kv_a_proj_wt = kv_a_proj_wt.contiguous()
|
||||
kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
|
||||
wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
|
||||
wd_qkv = wd_qkv.t().contiguous()
|
||||
wd_qkv = transdata(wd_qkv,
|
||||
|
||||
Reference in New Issue
Block a user