xc-llm-ascend/vllm_ascend/ops/weight_prefetch.py
huangdong2022 3a53bbc508 [Feat]Qwen3 Moe supports npu_add_rms_norm_quant op by default, update op with bias, resolve conflict with weight prefetch (#3465)
### What this PR does / why we need it?
1. In the quantization path, Qwen3 MoE now uses the fused `add_rms_norm_quant` op by default instead of separate `add_rms_norm` and `quant` ops (see the sketch below).
2. `torch_npu.add_rms_norm_quant` previously lost accuracy when model weights were quantized with anti_method m4, an asymmetric outlier-suppression method that produces a non-zero norm bias; the op call is updated to pass this bias into the calculation.
3. Add a torch_npu version check.
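
For context, a rough before/after sketch of the fusion. The argument names and return shapes below are assumptions for illustration only; consult the torch_npu documentation for the authoritative signatures.

```python
import torch
import torch_npu


def fused_norm_quant(x: torch.Tensor, residual: torch.Tensor,
                     gamma: torch.Tensor, scale: torch.Tensor,
                     offset: torch.Tensor, bias: torch.Tensor,
                     eps: float = 1e-6):
    # Before this PR: two kernels, add_rms_norm followed by a separate
    # quantize step (quantize_to_int8 is a hypothetical placeholder):
    #   normed, _, residual_out = torch_npu.npu_add_rms_norm(
    #       x, residual, gamma, eps)
    #   quantized = quantize_to_int8(normed, scale, offset)
    # After: one fused kernel that also accepts the non-zero norm bias
    # produced by anti_method m4 (parameter names assumed; the real
    # torch_npu signature may differ across versions).
    quantized, _, residual_out = torch_npu.npu_add_rms_norm_quant(
        x, residual, gamma, scale, offset, beta=bias, epsilon=eps)
    return quantized, residual_out
```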

### Does this PR introduce _any_ user-facing change?
No new flags or environment variables. The new feature activates automatically when the torch_npu version is >= torch_npu-2.7.1.dev20250919 (see the version-check sketch below).
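
The gate itself is `version_check` from `vllm_ascend.utils` (imported in the file below). A minimal sketch of how such a dev-build check might look; the helper name and comparison logic here are illustrative assumptions, not the actual implementation:

```python
import torch_npu


def npu_version_at_least(required: str = "2.7.1.dev20250919") -> bool:
    """Illustrative stand-in for vllm_ascend.utils.version_check."""
    version = getattr(torch_npu, "__version__", "")
    # Dev builds embed a YYYYMMDD snapshot date after ".dev", so a plain
    # string comparison of that suffix orders them chronologically.
    # Non-dev builds fall through to False in this simplified sketch.
    if ".dev" not in version:
        return False
    return version.split(".dev")[-1] >= required.split(".dev")[-1]
```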

### How was this patch tested?
1. No special parameters or environment variables need to be set; the feature activates when the torch_npu version is >= torch_npu-2.7.1.dev20250919.
2. Tested with quantized Qwen3 MoE models such as Qwen3-235B-A22B-W8A8, Qwen3-30B-A3B-W8A8, and Qwen3-235B-A22B-Instruct-2507-m4 (anti_method m4).

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: h30027576 <huangdong51@huawei.com>
2025-10-17 09:30:51 +08:00

115 lines · 4.3 KiB · Python

from dataclasses import dataclass, field

import torch
import torch_npu
from vllm.forward_context import get_forward_context

from vllm_ascend.ascend_config import WeightPrefetchConfig
from vllm_ascend.ops.linear import (AscendQKVParallelLinear,
                                    AscendRowParallelLinear)
from vllm_ascend.utils import version_check

SUPPORTED_MODULES = ["attn", "mlp", "moe"]
# Only prefetch MoE weights when the batch holds at least this many tokens.
MOE_PREFETCH_TOKEN_THRESHOLD = 96

@dataclass
class ModuleWeightPrefetchConfig:
    module_name: str
    enable: bool = False
    is_active_this_forward: bool = False
    prefetch_ratio: dict = field(default_factory=dict)
    linear_prefix_map: dict = field(default_factory=dict)

    def __post_init__(self) -> None:
        # Drop ratios outside the valid range [0, 1].
        self.prefetch_ratio = {
            prefix: ratio
            for prefix, ratio in self.prefetch_ratio.items() if 0 <= ratio <= 1
        }
        assert self.module_name in SUPPORTED_MODULES, (
            f"Invalid module name {self.module_name}, should be one of {SUPPORTED_MODULES}"
        )
        # Stay enabled only if at least one configured ratio is positive.
        self.enable = self.enable and any(
            ratio > 0 for ratio in self.prefetch_ratio.values())


class WeightPrefetchMethod:
    """
    Unified weight prefetch method.
    """

    def __init__(self, weight_prefetch_config: WeightPrefetchConfig) -> None:
        self.attn = ModuleWeightPrefetchConfig(
            module_name="attn",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "attn", {}),
            linear_prefix_map={
                AscendQKVParallelLinear.__name__: "qkv",
                AscendRowParallelLinear.__name__: "o",
            })
        self.moe = ModuleWeightPrefetchConfig(
            module_name="moe",
            enable=weight_prefetch_config.enabled,
            prefetch_ratio=weight_prefetch_config.prefetch_ratio.get(
                "moe", {}))

    def maybe_prefetch_attn_weight_preprocess(
            self, layer_cls_name: str, weight: torch.Tensor,
            start_flag: torch.Tensor) -> None:
        if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
            return
        prefix = self.attn.linear_prefix_map.get(layer_cls_name, "")
        weight_size = (weight.data.element_size() * weight.data.numel() *
                       self.attn.prefetch_ratio.get(prefix, 0))
        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=start_flag,
                                           max_weight_size=int(weight_size))

    def maybe_prefetch_attn_weight_postprocess(
            self, layer_cls_name: str, stop_flag: torch.Tensor) -> None:
        if not self.attn.enable or layer_cls_name not in self.attn.linear_prefix_map:
            return
        torch.ops.vllm.prefetch_postprocess(stop_flag)

    def maybe_prefetch_moe_weight_preprocess(self,
                                             hidden_states: torch.Tensor,
                                             prefix: str) -> None:
        # MoE prefetching only pays off for sufficiently large batches.
        self.moe.is_active_this_forward = (
            self.moe.enable
            and hidden_states.shape[0] >= MOE_PREFETCH_TOKEN_THRESHOLD)
        if not self.moe.is_active_this_forward:
            return
        forward_context = get_forward_context()
        if not version_check():
            forward_context.layer_idx += 1
        weight = forward_context.model_instance.model.layers[
            forward_context.layer_idx - 1].mlp.experts.w13_weight
        weight_size = (weight.data.element_size() * weight.data.numel() *
                       self.moe.prefetch_ratio.get(prefix, 0))
        torch.ops.vllm.prefetch_preprocess(weight=weight,
                                           start_flag=None,
                                           max_weight_size=int(weight_size))

    def maybe_prefetch_moe_weight_postprocess(self,
                                              stop_flag: torch.Tensor) -> None:
        if not self.moe.is_active_this_forward:
            return
        torch.ops.vllm.prefetch_postprocess(stop_flag)


def maybe_npu_prefetch(inputs: torch.Tensor,
                       dependency: torch.Tensor,
                       max_size: int = 0,
                       offset: int = 0,
                       *,
                       enabled: bool = True) -> None:
    if not enabled:
        return
    # Clamp the requested size to the full tensor size when unset or too large.
    input_size = inputs.element_size() * inputs.numel()
    if max_size <= 0 or max_size > input_size:
        max_size = input_size
    torch_npu.npu_prefetch(inputs, dependency, max_size, offset)
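
For orientation, a hypothetical call site showing how these hooks bracket a layer's forward pass. `weight_prefetch_config`, `qkv_weight`, `start_flag`, `stop_flag`, `hidden_states`, and `attention_forward` are illustrative names, not part of this file; in vllm-ascend the calls live inside the Ascend linear layers themselves.

```python
method = WeightPrefetchMethod(weight_prefetch_config)

# Kick off an async prefetch of the QKV weight before its matmul runs.
method.maybe_prefetch_attn_weight_preprocess(
    layer_cls_name="AscendQKVParallelLinear",
    weight=qkv_weight,
    start_flag=start_flag)
output = attention_forward(hidden_states)  # compute overlaps the prefetch
# Signal that the prefetch window for this layer is over.
method.maybe_prefetch_attn_weight_postprocess(
    layer_cls_name="AscendQKVParallelLinear",
    stop_flag=stop_flag)
```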