[Model] Support Minimax-m2.5 on NPU (#7105)
### What this PR does / why we need it?
Initial version to support minimax-m2.5 on vllm-ascend.
This commit coverting original fp8 weight to a quantilized bf16 to
support Minimax-m2.5 on NPU.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
### Test Report
Self tested precision summary, where the official precision score of
AIME2025 is 86.3
<img width="426" height="84" alt="image"
src="https://github.com/user-attachments/assets/a3ce2452-92fa-4713-962e-862248e0b61a"
/>
---------
Signed-off-by: limuyuan <limuyuan3@huawei.com>
Signed-off-by: SparrowMu <52023119+SparrowMu@users.noreply.github.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
This commit is contained in:
@@ -108,6 +108,35 @@
|
||||
# remove this patch once upstream no longer requires these global symbols or
|
||||
# provides a backend-safe initialization path.
|
||||
#
|
||||
# ** 7. File: platform/patch_minimax_m2_config.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.config.model.ModelConfig._verify_quantization`
|
||||
# Why:
|
||||
# MiniMax-M2 fp8 checkpoints on NPU may fail upstream quantization validation.
|
||||
# vllm-ascend needs to disable fp8 quantization and load bf16 dequantized
|
||||
# weights in worker-side patches instead.
|
||||
# How:
|
||||
# Monkey-patch `_verify_quantization` and intercept platform quantization
|
||||
# verification to force `cfg.quantization=None` for MiniMax-M2 fp8 on NPU.
|
||||
# Related PR (if no, explain why):
|
||||
# No, upstream behavior differs across versions and needs discussion.
|
||||
# Future Plan:
|
||||
# Remove this patch once upstream supports MiniMax-M2 fp8 on NPU or provides
|
||||
# a backend-safe validation / override mechanism.
|
||||
#
|
||||
# 2. `vllm.config.model.ModelConfig._verify_cuda_graph`
|
||||
# Why:
|
||||
# For MiniMax-M2 on NPU with ACL graph capture enabled, HCCL op expansion
|
||||
# mode affects graph shape coverage. Users may forget to set it.
|
||||
# How:
|
||||
# If user doesn't set it, set `HCCL_OP_EXPANSION_MODE=AIV` for this model
|
||||
# and log a warning when a different value is detected.
|
||||
# Related PR (if no, explain why):
|
||||
# No, this is an environment-specific tuning knob.
|
||||
# Future Plan:
|
||||
# Remove this patch if upstream provides an official NPU graph-capture
|
||||
# guidance / auto-configuration path for HCCL.
|
||||
#
|
||||
# * Worker Patch:
|
||||
# ===============
|
||||
#
|
||||
@@ -333,7 +362,73 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM merges the PR.
|
||||
#
|
||||
# ** 17. File: worker/patch_qwen3_5.py**
|
||||
# ** 17. File: worker/patch_minimax_m2.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.minimax_m2.MiniMaxM2MoE.forward`
|
||||
# Why:
|
||||
# In TP mode, MiniMax-M2 MoE needs a backend-aware reduction path to avoid
|
||||
# unnecessary communication / maintain correctness on NPU.
|
||||
# How:
|
||||
# Replace the forward to call `experts.maybe_all_reduce_tensor_model_parallel`
|
||||
# when `tp_size > 1`.
|
||||
# Related PR (if no, explain why):
|
||||
# No, model-specific behavior.
|
||||
# Future Plan:
|
||||
# Move this behavior upstream once a generic MoE reduce hook exists.
|
||||
#
|
||||
# 2. `vllm.model_executor.models.minimax_m2.MiniMaxM2Attention.__init__`
|
||||
# Why:
|
||||
# When total kv heads < TP world size, kv head replication happens and k_norm
|
||||
# weights should be sharded to match the replication layout.
|
||||
# How:
|
||||
# Add `num_kv_head_replicas` and create sharded `k_norm` via
|
||||
# `MiniMaxText01RMSNormTP(..., weight_shard_world_size=total_num_kv_heads, ...)`.
|
||||
# Related PR (if no, explain why):
|
||||
# No, depends on Ascend kernel behavior and TP layout.
|
||||
# Future Plan:
|
||||
# Remove this patch if upstream implements kv-head-aware norm sharding.
|
||||
#
|
||||
# 3. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.load_weights`
|
||||
# Why:
|
||||
# MiniMax-M2 fp8 checkpoints may store fp8 weights with per-block inverse
|
||||
# scales. On NPU we load bf16 weights by dequantizing at load time.
|
||||
# How:
|
||||
# Inject fp8 dequant helpers and wrap `load_weights` to convert fp8 weight +
|
||||
# `weight_scale_inv` pairs into bf16 blocks before delegating to upstream.
|
||||
# Related PR (if no, explain why):
|
||||
# No, fp8 load format and backend constraints are model/backend specific.
|
||||
# Future Plan:
|
||||
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
|
||||
#
|
||||
# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
|
||||
# `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.weight_loader`
|
||||
# Why:
|
||||
# MiniMax-M2 linear attention RMSNorm needs weight sharding that can follow
|
||||
# TP layout (and sometimes kv-head replication) on NPU.
|
||||
# How:
|
||||
# Override `__init__` to parameterize weight shard world/rank and install a
|
||||
# sharded `weight_loader` implementation.
|
||||
# Related PR (if no, explain why):
|
||||
# No, upstream API surface differs across versions.
|
||||
# Future Plan:
|
||||
# Remove this patch when upstream exposes stable sharding hooks for this layer.
|
||||
#
|
||||
# 2. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.forward_qk`
|
||||
# (or older `_normalize_qk`)
|
||||
# Why:
|
||||
# q/k norm for linear attention is performance-sensitive. On NPU, a fused
|
||||
# rms_norm kernel is faster and TP needs a global rstd correction.
|
||||
# How:
|
||||
# Replace q/k normalization with NPU rms_norm fast path and TP-global rstd
|
||||
# correction; fall back to upstream implementation on non-NPU.
|
||||
# Related PR (if no, explain why):
|
||||
# No, backend-specific optimization.
|
||||
# Future Plan:
|
||||
# Remove this patch when upstream adds a backend dispatch path for q/k norm.
|
||||
#
|
||||
# ** 19. File: worker/patch_qwen3_5.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
|
||||
# Why:
|
||||
|
||||
@@ -19,6 +19,7 @@ import os
|
||||
import vllm_ascend.patch.platform.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_fusion_matcher_compat_ops # noqa
|
||||
import vllm_ascend.patch.platform.patch_mamba_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_minimax_m2_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
||||
|
||||
if os.getenv("DYNAMIC_EPLB", "false").lower() in ("true", "1") or os.getenv("EXPERT_MAP_RECORD", "false") == "true":
|
||||
|
||||
138
vllm_ascend/patch/platform/patch_minimax_m2_config.py
Normal file
138
vllm_ascend/patch/platform/patch_minimax_m2_config.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Patch target: vllm/config/model.py
|
||||
# - MiniMax-M2 fp8 checkpoint on NPU: disable fp8 quantization (load bf16
|
||||
# dequantized weights in worker patch) instead of failing validation.
|
||||
# - For ACL graph capture, set HCCL_OP_EXPANSION_MODE=AIV if user didn't set it.
|
||||
#
|
||||
|
||||
import os
|
||||
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_original_verify_quantization = getattr(ModelConfig, "_verify_quantization", None)
|
||||
_original_verify_cuda_graph = getattr(ModelConfig, "_verify_cuda_graph", None)
|
||||
|
||||
_DISABLE_FP8_LOG = (
|
||||
"Detected fp8 MiniMax-M2 checkpoint on NPU. "
|
||||
"Disabling fp8 quantization and loading dequantized bf16 "
|
||||
"weights instead."
|
||||
)
|
||||
|
||||
|
||||
def _get_model_type(cfg: ModelConfig) -> str | None:
|
||||
# vLLM config fields have changed across versions; try multiple sources.
|
||||
model_arch_cfg = getattr(cfg, "model_arch_config", None)
|
||||
if model_arch_cfg is not None:
|
||||
mt = getattr(model_arch_cfg, "model_type", None)
|
||||
if mt:
|
||||
return mt
|
||||
|
||||
hf_text_cfg = getattr(cfg, "hf_text_config", None)
|
||||
if hf_text_cfg is not None:
|
||||
mt = getattr(hf_text_cfg, "model_type", None)
|
||||
if mt:
|
||||
return mt
|
||||
|
||||
hf_cfg = getattr(cfg, "hf_config", None)
|
||||
if hf_cfg is not None:
|
||||
mt = getattr(hf_cfg, "model_type", None)
|
||||
if mt:
|
||||
return mt
|
||||
|
||||
return getattr(cfg, "model_type", None)
|
||||
|
||||
|
||||
def _should_disable_fp8(cfg: ModelConfig, quant_method: str | None) -> bool:
|
||||
return current_platform.device_name == "npu" and _get_model_type(cfg) == "minimax_m2" and quant_method == "fp8"
|
||||
|
||||
|
||||
def _disable_fp8(cfg: ModelConfig, *, log: bool) -> bool:
|
||||
if not _should_disable_fp8(cfg, getattr(cfg, "quantization", None)):
|
||||
return False
|
||||
if log:
|
||||
logger.warning(_DISABLE_FP8_LOG)
|
||||
cfg.quantization = None
|
||||
return True
|
||||
|
||||
|
||||
def _patched_verify_quantization(self: ModelConfig) -> None:
|
||||
"""Inject mid-function behavior for ModelConfig._verify_quantization.
|
||||
|
||||
Upstream validates quantization inside this method via:
|
||||
current_platform.verify_quantization(self.quantization)
|
||||
|
||||
We emulate a mid-function patch without copying upstream code by temporarily
|
||||
overriding current_platform.verify_quantization while the original verifier
|
||||
executes.
|
||||
"""
|
||||
assert _original_verify_quantization is not None
|
||||
|
||||
orig_platform_verify = getattr(current_platform, "verify_quantization", None)
|
||||
|
||||
def _platform_verify_hook(quant_method: str | None) -> None:
|
||||
if _should_disable_fp8(self, quant_method):
|
||||
# This is the effective "middle of _verify_quantization" interception.
|
||||
_disable_fp8(self, log=True)
|
||||
return
|
||||
assert orig_platform_verify is not None
|
||||
return orig_platform_verify(quant_method)
|
||||
|
||||
# Some versions may read self.quantization before calling platform verifier.
|
||||
_disable_fp8(self, log=True)
|
||||
|
||||
try:
|
||||
if orig_platform_verify is not None:
|
||||
current_platform.verify_quantization = _platform_verify_hook
|
||||
return _original_verify_quantization(self)
|
||||
finally:
|
||||
if orig_platform_verify is not None:
|
||||
current_platform.verify_quantization = orig_platform_verify
|
||||
# Ensure fp8 isn't restored by upstream logic.
|
||||
_disable_fp8(self, log=False)
|
||||
|
||||
|
||||
def _patched_verify_cuda_graph(self: ModelConfig) -> None:
|
||||
assert _original_verify_cuda_graph is not None
|
||||
|
||||
if (
|
||||
current_platform.device_name == "npu"
|
||||
and _get_model_type(self) == "minimax_m2"
|
||||
and not getattr(self, "enforce_eager", True)
|
||||
):
|
||||
expansion_mode = os.environ.get("HCCL_OP_EXPANSION_MODE")
|
||||
if expansion_mode is None:
|
||||
os.environ["HCCL_OP_EXPANSION_MODE"] = "AIV"
|
||||
logger.info("Set HCCL_OP_EXPANSION_MODE=AIV for MiniMax-M2 ACL graph capture on NPU.")
|
||||
elif expansion_mode != "AIV":
|
||||
logger.warning(
|
||||
"HCCL_OP_EXPANSION_MODE=%s may reduce ACL graph shape "
|
||||
"coverage for MiniMax-M2 on NPU. Recommended value: AIV.",
|
||||
expansion_mode,
|
||||
)
|
||||
|
||||
return _original_verify_cuda_graph(self)
|
||||
|
||||
|
||||
if _original_verify_quantization is not None:
|
||||
ModelConfig._verify_quantization = _patched_verify_quantization
|
||||
|
||||
if _original_verify_cuda_graph is not None:
|
||||
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
|
||||
@@ -30,6 +30,8 @@ import vllm_ascend.patch.platform.patch_sched_yield # noqa
|
||||
import vllm_ascend.patch.worker.patch_unquantized_gemm # noqa
|
||||
import vllm_ascend.patch.worker.patch_bert # noqa
|
||||
import vllm_ascend.patch.worker.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_minimax_m2 # noqa
|
||||
import vllm_ascend.patch.worker.patch_minimax_m2_linear_attn # noqa
|
||||
import vllm_ascend.patch.worker.patch_multimodal_merge # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_next # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_next_mtp # noqa
|
||||
|
||||
174
vllm_ascend/patch/worker/patch_minimax_m2.py
Normal file
174
vllm_ascend/patch/worker/patch_minimax_m2.py
Normal file
@@ -0,0 +1,174 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MiniMax-M2 on Ascend: MoE all_reduce, k_norm weight sharding, fp8 load dequant.
|
||||
#
|
||||
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
FP8_DTYPES = tuple(
|
||||
getattr(torch, dtype_name)
|
||||
for dtype_name in (
|
||||
"float8_e4m3fn",
|
||||
"float8_e4m3fnuz",
|
||||
"float8_e5m2",
|
||||
"float8_e5m2fnuz",
|
||||
"float8_e8m0fnu",
|
||||
)
|
||||
if hasattr(torch, dtype_name)
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMaxM2MoE.forward: use maybe_all_reduce_tensor_model_parallel
|
||||
# ---------------------------------------------------------------------------
|
||||
def _patched_moe_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states.to(torch.float32))
|
||||
final_hidden_states = self.experts(hidden_states=hidden_states, router_logits=router_logits)
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
|
||||
MiniMaxM2MoE.forward = _patched_moe_forward
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMaxM2Attention: num_kv_head_replicas and k_norm weight sharding
|
||||
# ---------------------------------------------------------------------------
|
||||
_original_attention_init = MiniMaxM2Attention.__init__
|
||||
|
||||
|
||||
def _patched_attention_init(self, *args, **kwargs) -> None:
|
||||
_original_attention_init(self, *args, **kwargs)
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_kv_head_replicas = max(1, tp_size // self.total_num_kv_heads)
|
||||
if self.total_num_kv_heads < tp_size:
|
||||
rms_norm_eps = getattr(getattr(self, "q_norm", None), "variance_epsilon", 1e-6)
|
||||
self.k_norm = MiniMaxText01RMSNormTP(
|
||||
self.head_dim * self.total_num_kv_heads,
|
||||
eps=rms_norm_eps,
|
||||
weight_shard_world_size=self.total_num_kv_heads,
|
||||
weight_shard_rank=get_tensor_model_parallel_rank() // self.num_kv_head_replicas,
|
||||
)
|
||||
|
||||
|
||||
MiniMaxM2Attention.__init__ = _patched_attention_init
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
def _need_dequantize_fp8_weights(self) -> bool:
|
||||
quant_cfg = getattr(self.config, "quantization_config", None)
|
||||
return (
|
||||
isinstance(quant_cfg, dict) and quant_cfg.get("quant_method") == "fp8" and current_platform.device_name == "npu"
|
||||
)
|
||||
|
||||
|
||||
def _dequantize_fp8_block_weight(
|
||||
fp8_weight: torch.Tensor,
|
||||
weight_scale_inv: torch.Tensor,
|
||||
block_size: tuple[int, int],
|
||||
) -> torch.Tensor:
|
||||
block_n, block_k = block_size
|
||||
n, k = fp8_weight.shape
|
||||
n_tiles = (n + block_n - 1) // block_n
|
||||
k_tiles = (k + block_k - 1) // block_k
|
||||
if tuple(weight_scale_inv.shape) != (n_tiles, k_tiles):
|
||||
raise ValueError(
|
||||
"Unexpected fp8 scale shape: "
|
||||
f"weight={tuple(fp8_weight.shape)}, "
|
||||
f"scale={tuple(weight_scale_inv.shape)}, "
|
||||
f"block_size={block_size}"
|
||||
)
|
||||
expanded_scale = weight_scale_inv.repeat_interleave(block_n, dim=0).repeat_interleave(block_k, dim=1)
|
||||
expanded_scale = expanded_scale[:n, :k].to(dtype=torch.bfloat16)
|
||||
return fp8_weight.to(dtype=torch.bfloat16) * expanded_scale
|
||||
|
||||
|
||||
def _fp8_dequant_weight_iter(
|
||||
self: "MiniMaxM2Model",
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
quant_cfg = getattr(self.config, "quantization_config", {})
|
||||
block_cfg = quant_cfg.get("weight_block_size", [128, 128])
|
||||
weight_block_size: tuple[int, int] = (128, 128)
|
||||
if isinstance(block_cfg, list) and len(block_cfg) == 2:
|
||||
weight_block_size = (int(block_cfg[0]), int(block_cfg[1]))
|
||||
|
||||
pending_fp8_weights: dict[str, torch.Tensor] = {}
|
||||
pending_fp8_scales: dict[str, torch.Tensor] = {}
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith(".weight_scale_inv"):
|
||||
paired_weight_name = name[: -len("_scale_inv")]
|
||||
pending_weight = pending_fp8_weights.pop(paired_weight_name, None)
|
||||
if pending_weight is None:
|
||||
pending_fp8_scales[name] = loaded_weight
|
||||
continue
|
||||
loaded_weight = self._dequantize_fp8_block_weight(pending_weight, loaded_weight, weight_block_size)
|
||||
name = paired_weight_name
|
||||
elif loaded_weight.dtype in FP8_DTYPES and name.endswith(".weight"):
|
||||
scale_name = f"{name}_scale_inv"
|
||||
pending_scale = pending_fp8_scales.pop(scale_name, None)
|
||||
if pending_scale is None:
|
||||
pending_fp8_weights[name] = loaded_weight
|
||||
continue
|
||||
loaded_weight = self._dequantize_fp8_block_weight(loaded_weight, pending_scale, weight_block_size)
|
||||
yield name, loaded_weight
|
||||
|
||||
if pending_fp8_weights or pending_fp8_scales:
|
||||
raise ValueError(
|
||||
"Unpaired fp8 MiniMax-M2 weight/scale tensors detected: "
|
||||
f"pending_weights={len(pending_fp8_weights)}, "
|
||||
f"pending_scales={len(pending_fp8_scales)}"
|
||||
)
|
||||
|
||||
|
||||
MiniMaxM2Model._need_dequantize_fp8_weights = _need_dequantize_fp8_weights
|
||||
MiniMaxM2Model._dequantize_fp8_block_weight = staticmethod(_dequantize_fp8_block_weight)
|
||||
MiniMaxM2Model._fp8_dequant_weight_iter = _fp8_dequant_weight_iter
|
||||
|
||||
_original_load_weights = MiniMaxM2Model.load_weights
|
||||
|
||||
|
||||
def _patched_load_weights(
|
||||
self: "MiniMaxM2Model",
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
) -> set[str]:
|
||||
if self._need_dequantize_fp8_weights():
|
||||
weights = self._fp8_dequant_weight_iter(weights)
|
||||
return _original_load_weights(self, weights)
|
||||
|
||||
|
||||
MiniMaxM2Model.load_weights = _patched_load_weights
|
||||
145
vllm_ascend/patch/worker/patch_minimax_m2_linear_attn.py
Normal file
145
vllm_ascend/patch/worker/patch_minimax_m2_linear_attn.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#
|
||||
# Copyright (c) 2026 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MiniMax-M2 linear attention: MiniMaxText01RMSNormTP weight sharding and NPU q/k norm path.
|
||||
#
|
||||
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.linear_attn import (
|
||||
CustomOp,
|
||||
MiniMaxText01RMSNormTP,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
_ORIG_QK_METHOD_NAME: str | None = None
|
||||
_original_qk_method = None
|
||||
_qk_is_staticmethod = False
|
||||
|
||||
if hasattr(MiniMaxText01RMSNormTP, "forward_qk"):
|
||||
_ORIG_QK_METHOD_NAME = "forward_qk"
|
||||
_original_qk_method = getattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME)
|
||||
elif hasattr(MiniMaxText01RMSNormTP, "_normalize_qk"):
|
||||
# Older vLLM versions
|
||||
_ORIG_QK_METHOD_NAME = "_normalize_qk"
|
||||
_original_qk_method = getattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME)
|
||||
|
||||
if _ORIG_QK_METHOD_NAME is not None:
|
||||
# Detect whether upstream defined it as a staticmethod (some versions do).
|
||||
_orig_desc = MiniMaxText01RMSNormTP.__dict__.get(_ORIG_QK_METHOD_NAME)
|
||||
_qk_is_staticmethod = isinstance(_orig_desc, staticmethod)
|
||||
|
||||
|
||||
def _patched_qk(
|
||||
q_norm: "MiniMaxText01RMSNormTP",
|
||||
k_norm: "MiniMaxText01RMSNormTP",
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# NPU fast path: kernelized local RMSNorm for q/k, then TP-global rstd correction.
|
||||
if current_platform.device_name == "npu":
|
||||
q, q_inv_rms = torch.ops.npu.npu_rms_norm(q, q_norm.weight, q_norm.variance_epsilon)
|
||||
k, k_inv_rms = torch.ops.npu.npu_rms_norm(k, k_norm.weight, k_norm.variance_epsilon)
|
||||
|
||||
if q_norm.tp_world > 1:
|
||||
q_local_inv_rms = q_inv_rms.to(torch.float32)
|
||||
if q_local_inv_rms.shape[-1] != 1:
|
||||
q_local_inv_rms = q_local_inv_rms.mean(dim=-1, keepdim=True)
|
||||
q_local_var = (q_local_inv_rms.reciprocal().pow(2) - q_norm.variance_epsilon).clamp_min_(0.0)
|
||||
|
||||
k_local_inv_rms = k_inv_rms.to(torch.float32)
|
||||
if k_local_inv_rms.shape[-1] != 1:
|
||||
k_local_inv_rms = k_local_inv_rms.mean(dim=-1, keepdim=True)
|
||||
k_local_var = (k_local_inv_rms.reciprocal().pow(2) - k_norm.variance_epsilon).clamp_min_(0.0)
|
||||
|
||||
qk_var = torch.cat([q_local_var, k_local_var], dim=-1)
|
||||
qk_var = tensor_model_parallel_all_reduce(qk_var) / q_norm.tp_world
|
||||
q_global_var, k_global_var = qk_var.chunk(2, dim=-1)
|
||||
|
||||
q_local_rstd = torch.rsqrt(q_local_var + q_norm.variance_epsilon)
|
||||
k_local_rstd = torch.rsqrt(k_local_var + k_norm.variance_epsilon)
|
||||
q_global_rstd = torch.rsqrt(q_global_var + q_norm.variance_epsilon)
|
||||
k_global_rstd = torch.rsqrt(k_global_var + k_norm.variance_epsilon)
|
||||
|
||||
q = q * (q_global_rstd / q_local_rstd).to(q.dtype)
|
||||
k = k * (k_global_rstd / k_local_rstd).to(k.dtype)
|
||||
|
||||
return q, k
|
||||
|
||||
assert _original_qk_method is not None
|
||||
# We install the patch as a staticmethod below, so prefer the static calling
|
||||
# convention for the original as well.
|
||||
return _original_qk_method(q_norm, k_norm, q, k)
|
||||
|
||||
|
||||
def _patched_weight_loader(
|
||||
param: nn.Parameter,
|
||||
loaded_weight: torch.Tensor,
|
||||
shard_world_size: int | None = None,
|
||||
shard_rank: int | None = None,
|
||||
) -> None:
|
||||
if shard_world_size is None:
|
||||
shard_world_size = get_tensor_model_parallel_world_size()
|
||||
if shard_rank is None:
|
||||
shard_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = loaded_weight.shape[0] // shard_world_size
|
||||
shard = slice(shard_rank * shard_size, (shard_rank + 1) * shard_size)
|
||||
param.data.copy_(loaded_weight[shard])
|
||||
|
||||
|
||||
def _patched_init(
|
||||
self: "MiniMaxText01RMSNormTP",
|
||||
hidden_size: int,
|
||||
eps: float = 1e-6,
|
||||
*,
|
||||
weight_shard_world_size: int | None = None,
|
||||
weight_shard_rank: int | None = None,
|
||||
) -> None:
|
||||
CustomOp.__init__(self)
|
||||
self.tp_world = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.weight_shard_world = weight_shard_world_size or self.tp_world
|
||||
self.weight_shard_rank = self.tp_rank if weight_shard_rank is None else weight_shard_rank
|
||||
|
||||
if hidden_size % self.weight_shard_world != 0:
|
||||
raise ValueError(
|
||||
"MiniMaxText01RMSNormTP hidden_size must be divisible by "
|
||||
f"weight_shard_world_size, got hidden_size={hidden_size}, "
|
||||
f"weight_shard_world_size={self.weight_shard_world}"
|
||||
)
|
||||
|
||||
self.weight = nn.Parameter(torch.ones(int(hidden_size / self.weight_shard_world)))
|
||||
self.weight.weight_loader = partial(
|
||||
_patched_weight_loader,
|
||||
shard_world_size=self.weight_shard_world,
|
||||
shard_rank=self.weight_shard_rank,
|
||||
)
|
||||
self.variance_epsilon = eps
|
||||
|
||||
|
||||
MiniMaxText01RMSNormTP.__init__ = _patched_init
|
||||
MiniMaxText01RMSNormTP.weight_loader = staticmethod(_patched_weight_loader)
|
||||
|
||||
if _ORIG_QK_METHOD_NAME is not None:
|
||||
# Force staticmethod style, as requested.
|
||||
setattr(MiniMaxText01RMSNormTP, _ORIG_QK_METHOD_NAME, staticmethod(_patched_qk))
|
||||
Reference in New Issue
Block a user