[Bugfix] Fix weight prefetching AssertionError in W8A8 MTP scene (#3361)

### What this PR does / why we need it?

- Fix `AssertionError` of `weight_prefetch_method` in W8A8 MTP scene
- Remove hard-code key
(https://github.com/vllm-project/vllm-ascend/pull/3146#discussion_r2416644010)

### Does this PR introduce _any_ user-facing change?

None

### How was this patch tested?
`weight_prefetch_method is None` (tested on DeepSeek-R1-w8a8mix_MTP)

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
Ruri
2025-10-11 09:24:02 +08:00
committed by GitHub
parent 8c1a4dedf3
commit 866f5e7283
2 changed files with 25 additions and 19 deletions

View File

@@ -98,31 +98,29 @@ class AscendW8A8LinearMethod:
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
if x.dtype != torch.int8:
attn_weight_map = {
"AscendQKVParallelLinear": "qkv",
"AscendRowParallelLinear": "o",
}
layer_cls_name = layer.__class__.__name__
weight_prefetch_method = get_forward_context(
).weight_prefetch_method
assert weight_prefetch_method is not None
# prefetch_qkvo_proj.weight preprocess
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
prefix=attn_weight_map.get(layer_cls_name, ""),
weight=layer.weight,
start_flag=x,
)
# prefetch qkvo_proj.weight preprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
layer_cls_name=layer_cls_name,
weight=layer.weight,
start_flag=x,
)
# quant
x = quant_per_tensor(
x,
layer.aclnn_input_scale_reciprocal,
layer.aclnn_input_offset,
)
# prefetch_qkvo_proj.weight postprocess
if layer_cls_name in attn_weight_map.keys():
# prefetch qkvo_proj.weight postprocess
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
x)
layer_cls_name=layer_cls_name,
stop_flag=x,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
if is_310p():