[Bugfix] Fix weight prefetching AssertionError in W8A8 MTP scene (#3361)
### What this PR does / why we need it? - Fix `AssertionError` of `weight_prefetch_method` in W8A8 MTP scene - Remove hard-code key (https://github.com/vllm-project/vllm-ascend/pull/3146#discussion_r2416644010) ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? `weight_prefetch_method is None` (tested on DeepSeek-R1-w8a8mix_MTP) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
@@ -98,31 +98,29 @@ class AscendW8A8LinearMethod:
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
if x.dtype != torch.int8:
|
||||
attn_weight_map = {
|
||||
"AscendQKVParallelLinear": "qkv",
|
||||
"AscendRowParallelLinear": "o",
|
||||
}
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
weight_prefetch_method = get_forward_context(
|
||||
).weight_prefetch_method
|
||||
assert weight_prefetch_method is not None
|
||||
|
||||
# prefetch_qkvo_proj.weight preprocess
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
prefix=attn_weight_map.get(layer_cls_name, ""),
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# prefetch qkvo_proj.weight preprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# quant
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
# prefetch_qkvo_proj.weight postprocess
|
||||
if layer_cls_name in attn_weight_map.keys():
|
||||
# prefetch qkvo_proj.weight postprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||
x)
|
||||
layer_cls_name=layer_cls_name,
|
||||
stop_flag=x,
|
||||
)
|
||||
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
if is_310p():
|
||||
|
||||
Reference in New Issue
Block a user