[Bugfix] Fix weight prefetching AssertionError in W8A8 MTP scenario (#3361)
### What this PR does / why we need it?

- Fix the `AssertionError` on `weight_prefetch_method` in the W8A8 MTP scenario.
- Remove the hard-coded key (https://github.com/vllm-project/vllm-ascend/pull/3146#discussion_r2416644010).

### Does this PR introduce _any_ user-facing change?

None

### How was this patch tested?

`weight_prefetch_method is None` (tested on DeepSeek-R1-w8a8mix_MTP)

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: zhoux77899 <zhouxiang100@huawei.com>
This commit is contained in:
@@ -4,6 +4,8 @@ import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ascend_config import WeightPrefetchConfig
|
||||
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
|
||||
AscendRowParallelLinear)
|
||||
|
||||
SUPPORTED_MODULES = ["attn", "mlp", "moe"]
|
||||
|
||||
@@ -13,6 +15,7 @@ class ModuleWeightPrefetchConfig:
|
||||
module_name: str
|
||||
enable: bool = False
|
||||
prefetch_ratio: dict = field(default_factory=dict)
|
||||
linear_prefix_map: dict = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
self.prefetch_ratio = {
|
||||
@@ -38,14 +41,19 @@ class WeightPrefetchMethod:
|
||||
module_name="attn",
|
||||
enable=weight_prefetch_config.enabled,
|
||||
prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
|
||||
"attn", {}))
|
||||
"attn", {}),
|
||||
linear_prefix_map={
|
||||
AscendQKVParallelLinear.__name__: "qkv",
|
||||
AscendRowParallelLinear.__name__: "o",
|
||||
})
|
||||
|
||||
def maybe_prefetch_attn_weight_preprocess(
|
||||
self, prefix: str, weight: torch.Tensor,
|
||||
self, layer_cls_name: str, weight: torch.Tensor,
|
||||
start_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
|
||||
weight_size = weight.data.element_size() * weight.data.numel(
|
||||
) * self.attn.prefetch_ratio.get(prefix, 0)
|
||||
|
||||
@@ -54,8 +62,8 @@ class WeightPrefetchMethod:
|
||||
max_weight_size=int(weight_size))
|
||||
|
||||
def maybe_prefetch_attn_weight_postprocess(
|
||||
self, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable:
|
||||
self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
|
||||
if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
|
||||
return
|
||||
|
||||
torch.ops.vllm.prefetch_postprocess(stop_flag)
|
||||
|
||||
@@ -98,31 +98,29 @@ class AscendW8A8LinearMethod:
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
if x.dtype != torch.int8:
|
||||
attn_weight_map = {
|
||||
"AscendQKVParallelLinear": "qkv",
|
||||
"AscendRowParallelLinear": "o",
|
||||
}
|
||||
layer_cls_name = layer.__class__.__name__
|
||||
weight_prefetch_method = get_forward_context(
|
||||
).weight_prefetch_method
|
||||
assert weight_prefetch_method is not None
|
||||
|
||||
# prefetch_qkvo_proj.weight preprocess
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
prefix=attn_weight_map.get(layer_cls_name, ""),
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# prefetch qkvo_proj.weight preprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_preprocess(
|
||||
layer_cls_name=layer_cls_name,
|
||||
weight=layer.weight,
|
||||
start_flag=x,
|
||||
)
|
||||
# quant
|
||||
x = quant_per_tensor(
|
||||
x,
|
||||
layer.aclnn_input_scale_reciprocal,
|
||||
layer.aclnn_input_offset,
|
||||
)
|
||||
# prefetch_qkvo_proj.weight postprocess
|
||||
if layer_cls_name in attn_weight_map.keys():
|
||||
# prefetch qkvo_proj.weight postprocess
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_attn_weight_postprocess(
|
||||
x)
|
||||
layer_cls_name=layer_cls_name,
|
||||
stop_flag=x,
|
||||
)
|
||||
|
||||
quant_bias = layer.quant_bias if tp_rank == 0 else None
|
||||
if is_310p():
|
||||
|
||||
Reference in New Issue
Block a user