From 866f5e7283ac954bab0b04d189b1a5599dfed596 Mon Sep 17 00:00:00 2001 From: Ruri <33858552+zhoux77899@users.noreply.github.com> Date: Sat, 11 Oct 2025 09:24:02 +0800 Subject: [PATCH] [Bugfix] Fix weight prefetching `AssertionError` in W8A8 MTP scene (#3361) ### What this PR does / why we need it? - Fix `AssertionError` of `weight_prefetch_method` in W8A8 MTP scene - Remove hard-code key (https://github.com/vllm-project/vllm-ascend/pull/3146#discussion_r2416644010) ### Does this PR introduce _any_ user-facing change? None ### How was this patch tested? `weight_prefetch_method is None` (tested on DeepSeek-R1-w8a8mix_MTP) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: zhoux77899 --- vllm_ascend/ops/weight_prefetch.py | 18 +++++++++++++----- vllm_ascend/quantization/w8a8.py | 26 ++++++++++++-------------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/vllm_ascend/ops/weight_prefetch.py b/vllm_ascend/ops/weight_prefetch.py index a6004c5..080d56b 100644 --- a/vllm_ascend/ops/weight_prefetch.py +++ b/vllm_ascend/ops/weight_prefetch.py @@ -4,6 +4,8 @@ import torch import torch_npu from vllm_ascend.ascend_config import WeightPrefetchConfig +from vllm_ascend.ops.linear import (AscendQKVParallelLinear, + AscendRowParallelLinear) SUPPORTED_MODULES = ["attn", "mlp", "moe"] @@ -13,6 +15,7 @@ class ModuleWeightPrefetchConfig: module_name: str enable: bool = False prefetch_ratio: dict = field(default_factory=dict) + linear_prefix_map: dict = field(default_factory=dict) def __post_init__(self) -> None: self.prefetch_ratio = { @@ -38,14 +41,19 @@ class WeightPrefetchMethod: module_name="attn", enable=weight_prefetch_config.enabled, prefetch_ratio=weight_prefetch_config.prefetch_ratio.get( - "attn", {})) + "attn", {}), + linear_prefix_map={ + AscendQKVParallelLinear.__name__: "qkv", + AscendRowParallelLinear.__name__: "o", + }) def maybe_prefetch_attn_weight_preprocess( - self, prefix: str, weight: torch.Tensor, + self, layer_cls_name: str, weight: torch.Tensor, start_flag: torch.Tensor) -> None: - if not self.attn.enable: + if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return + prefix = self.attn.linear_prefix_map.get(layer_cls_name, "") weight_size = weight.data.element_size() * weight.data.numel( ) * self.attn.prefetch_ratio.get(prefix, 0) @@ -54,8 +62,8 @@ class WeightPrefetchMethod: max_weight_size=int(weight_size)) def maybe_prefetch_attn_weight_postprocess( - self, stop_flag: torch.Tensor) -> None: - if not self.attn.enable: + self, layer_cls_name: str, stop_flag: torch.Tensor) -> None: + if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map: return torch.ops.vllm.prefetch_postprocess(stop_flag) diff --git a/vllm_ascend/quantization/w8a8.py b/vllm_ascend/quantization/w8a8.py index 433dbab..fb4c5a4 100644 --- a/vllm_ascend/quantization/w8a8.py +++ b/vllm_ascend/quantization/w8a8.py @@ -98,31 +98,29 @@ class AscendW8A8LinearMethod: tp_rank: Optional[int] = 0, ) -> torch.Tensor: if x.dtype != torch.int8: - attn_weight_map = { - "AscendQKVParallelLinear": "qkv", - "AscendRowParallelLinear": "o", - } layer_cls_name = layer.__class__.__name__ weight_prefetch_method = get_forward_context( ).weight_prefetch_method - assert weight_prefetch_method is not None - # prefetch_qkvo_proj.weight preprocess - weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( - prefix=attn_weight_map.get(layer_cls_name, ""), - weight=layer.weight, - start_flag=x, - ) + # prefetch qkvo_proj.weight preprocess + if weight_prefetch_method: + weight_prefetch_method.maybe_prefetch_attn_weight_preprocess( + layer_cls_name=layer_cls_name, + weight=layer.weight, + start_flag=x, + ) # quant x = quant_per_tensor( x, layer.aclnn_input_scale_reciprocal, layer.aclnn_input_offset, ) - # prefetch_qkvo_proj.weight postprocess - if layer_cls_name in attn_weight_map.keys(): + # prefetch qkvo_proj.weight postprocess + if weight_prefetch_method: weight_prefetch_method.maybe_prefetch_attn_weight_postprocess( - x) + layer_cls_name=layer_cls_name, + stop_flag=x, + ) quant_bias = layer.quant_bias if tp_rank == 0 else None if is_310p():