[Model] Support Minimax-m2.5 on NPU (#7105)
### What this PR does / why we need it?
Initial version to support Minimax-m2.5 on vllm-ascend.
This commit converts the original fp8 weights to dequantized bf16 to
support Minimax-m2.5 on NPU.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
### Test Report
Self-tested precision summary; the official AIME2025 precision score is
86.3.
<img width="426" height="84" alt="image"
src="https://github.com/user-attachments/assets/a3ce2452-92fa-4713-962e-862248e0b61a"
/>
---------
Signed-off-by: limuyuan <limuyuan3@huawei.com>
Signed-off-by: SparrowMu <52023119+SparrowMu@users.noreply.github.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
@@ -108,6 +108,35 @@
# remove this patch once upstream no longer requires these global symbols or
# provides a backend-safe initialization path.
#
# ** 7. File: platform/patch_minimax_m2_config.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.config.model.ModelConfig._verify_quantization`
# Why:
# MiniMax-M2 fp8 checkpoints on NPU may fail upstream quantization validation.
# vllm-ascend needs to disable fp8 quantization and load bf16 dequantized
# weights in worker-side patches instead.
# How:
# Monkey-patch `_verify_quantization` and intercept platform quantization
# verification to force `cfg.quantization=None` for MiniMax-M2 fp8 on NPU.
# Related PR (if no, explain why):
# No, upstream behavior differs across versions and needs discussion.
# Future Plan:
# Remove this patch once upstream supports MiniMax-M2 fp8 on NPU or provides
# a backend-safe validation / override mechanism.
#
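A minimal sketch of the monkey-patch pattern described above, using a stub class in place of vLLM's `ModelConfig` (all names here are illustrative, not vLLM's actual API surface):

```python
# Stub standing in for vllm.config.model.ModelConfig; the patching pattern
# below is the point, the class itself is illustrative.
class StubModelConfig:
    def __init__(self, model, quantization):
        self.model = model
        self.quantization = quantization

    def _verify_quantization(self):
        # Upstream-style validation: reject fp8 on backends without kernels.
        if self.quantization == "fp8":
            raise ValueError("fp8 is not supported on this platform")


# Keep a handle on the original so non-MiniMax models still get validated.
_orig_verify = StubModelConfig._verify_quantization


def _patched_verify_quantization(self):
    # For MiniMax-M2 fp8 checkpoints on NPU: skip upstream fp8 validation and
    # clear the quantization flag; worker-side patches dequantize to bf16.
    if "MiniMax-M2" in self.model and self.quantization == "fp8":
        self.quantization = None
        return
    _orig_verify(self)


StubModelConfig._verify_quantization = _patched_verify_quantization

cfg = StubModelConfig(model="MiniMax-M2", quantization="fp8")
cfg._verify_quantization()
print(cfg.quantization)  # None: fp8 validation bypassed for this model
```

Delegating to the saved original keeps the patch narrow: only the MiniMax-M2 fp8 combination is intercepted, everything else follows the upstream path.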
# 2. `vllm.config.model.ModelConfig._verify_cuda_graph`
# Why:
# For MiniMax-M2 on NPU with ACL graph capture enabled, HCCL op expansion
# mode affects graph shape coverage. Users may forget to set it.
# How:
# If the user doesn't set it, set `HCCL_OP_EXPANSION_MODE=AIV` for this model
# and log a warning when a different value is detected.
# Related PR (if no, explain why):
# No, this is an environment-specific tuning knob.
# Future Plan:
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
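The set-default-or-warn behavior can be sketched as follows; the helper name is illustrative, while `HCCL_OP_EXPANSION_MODE` is the real environment variable the patch manages:

```python
import os
import warnings


def ensure_hccl_expansion_mode(default="AIV"):
    """Apply the default only when the user hasn't set the variable;
    warn (without overriding) when a different value is already present."""
    current = os.environ.get("HCCL_OP_EXPANSION_MODE")
    if current is None:
        os.environ["HCCL_OP_EXPANSION_MODE"] = default
        return default
    if current != default:
        warnings.warn(
            f"HCCL_OP_EXPANSION_MODE={current!r} differs from the "
            f"recommended {default!r} for MiniMax-M2 ACL graph capture")
    return current


os.environ.pop("HCCL_OP_EXPANSION_MODE", None)  # simulate an unset environment
print(ensure_hccl_expansion_mode())  # AIV
```

Note that an explicit user setting always wins; the patch only fills the gap and surfaces a warning, which matches the "users may forget to set it" motivation above.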
# * Worker Patch:
# ===============
#
@@ -333,7 +362,73 @@
# Future Plan:
# Remove this patch when vLLM merges the PR.
#
# ** 17. File: worker/patch_minimax_m2.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.minimax_m2.MiniMaxM2MoE.forward`
# Why:
# In TP mode, MiniMax-M2 MoE needs a backend-aware reduction path to avoid
# unnecessary communication / maintain correctness on NPU.
# How:
# Replace the forward to call `experts.maybe_all_reduce_tensor_model_parallel`
# when `tp_size > 1`.
# Related PR (if no, explain why):
# No, model-specific behavior.
# Future Plan:
# Move this behavior upstream once a generic MoE reduce hook exists.
#
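The conditional reduction path can be sketched with stubs standing in for the MoE experts module and tensor-parallel state (all names illustrative; the real expert computation and communication are elided):

```python
class StubExperts:
    def maybe_all_reduce_tensor_model_parallel(self, hidden):
        # In a real deployment this dispatches to a backend-aware all-reduce;
        # here we just tag the tensor so the code path is observable.
        return ("reduced", hidden)


class StubMoE:
    def __init__(self, tp_size):
        self.tp_size = tp_size
        self.experts = StubExperts()

    def forward(self, hidden):
        out = hidden  # expert routing/computation elided
        if self.tp_size > 1:
            # Patched path: let the experts module choose the reduction
            # backend instead of an unconditional all-reduce.
            out = self.experts.maybe_all_reduce_tensor_model_parallel(out)
        return out


x = [1.0, 2.0]
print(StubMoE(tp_size=2).forward(x))  # ('reduced', [1.0, 2.0])
print(StubMoE(tp_size=1).forward(x))  # [1.0, 2.0]: no reduction at tp_size=1
```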
# 2. `vllm.model_executor.models.minimax_m2.MiniMaxM2Attention.__init__`
# Why:
# When total kv heads < TP world size, kv head replication happens and k_norm
# weights should be sharded to match the replication layout.
# How:
# Add `num_kv_head_replicas` and create sharded `k_norm` via
# `MiniMaxText01RMSNormTP(..., weight_shard_world_size=total_num_kv_heads, ...)`.
# Related PR (if no, explain why):
# No, depends on Ascend kernel behavior and TP layout.
# Future Plan:
# Remove this patch if upstream implements kv-head-aware norm sharding.
#
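The index arithmetic behind this layout can be sketched as a standalone helper (the function name is hypothetical; it assumes `tp_size` is an integer multiple of the kv head count, the usual replication case):

```python
def kv_head_replicas_and_shard(total_num_kv_heads, tp_size, tp_rank, head_dim):
    """Return how many ranks share one kv head, and the [start, end) slice of
    the full k_norm weight that this rank should load for the head it serves."""
    # Number of ranks replicating each kv head (assumes tp_size % heads == 0).
    num_kv_head_replicas = max(1, tp_size // total_num_kv_heads)
    # Which kv head this rank serves, and the matching weight slice.
    kv_head_index = tp_rank // num_kv_head_replicas
    start = kv_head_index * head_dim
    return num_kv_head_replicas, (start, start + head_dim)


# 4 kv heads over 8 ranks: each head is replicated on 2 ranks, and rank 5
# serves head 2, so it loads the slice covering that head's dimensions.
print(kv_head_replicas_and_shard(4, 8, 5, 64))  # (2, (128, 192))
```

Sharding `k_norm` to match this layout is what keeps each rank's norm weights aligned with the kv head it actually holds.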
# 3. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.load_weights`
# Why:
# MiniMax-M2 fp8 checkpoints may store fp8 weights with per-block inverse
# scales. On NPU we load bf16 weights by dequantizing at load time.
# How:
# Inject fp8 dequant helpers and wrap `load_weights` to convert fp8 weight +
# `weight_scale_inv` pairs into bf16 blocks before delegating to upstream.
# Related PR (if no, explain why):
# No, fp8 load format and backend constraints are model/backend specific.
# Future Plan:
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
#
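Per-block dequantization itself is simple arithmetic; a minimal NumPy sketch (float32 standing in for bf16, block size and function name illustrative):

```python
import numpy as np


def dequant_fp8_block(weight, scale_inv, block=128):
    """Multiply each (block x block) tile of `weight` by its per-block inverse
    scale, yielding a higher-precision tensor. `scale_inv[i, j]` scales the
    tile starting at (i * block, j * block)."""
    out = weight.astype(np.float32).copy()
    rows, cols = weight.shape
    for i in range(0, rows, block):
        for j in range(0, cols, block):
            out[i:i + block, j:j + block] *= scale_inv[i // block, j // block]
    return out
```

In the real patch this conversion happens inside the wrapped `load_weights`, so upstream loading code only ever sees bf16 tensors.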
# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
# `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.weight_loader`
# Why:
# MiniMax-M2 linear attention RMSNorm needs weight sharding that can follow
# TP layout (and sometimes kv-head replication) on NPU.
# How:
# Override `__init__` to parameterize weight shard world/rank and install a
# sharded `weight_loader` implementation.
# Related PR (if no, explain why):
# No, upstream API surface differs across versions.
# Future Plan:
# Remove this patch when upstream exposes stable sharding hooks for this layer.
#
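The core of a sharded weight loader is a slice copy parameterized by shard rank and world size; a minimal NumPy sketch (function name and rank bookkeeping are illustrative):

```python
import numpy as np


def sharded_norm_weight_loader(param, loaded_weight, shard_rank,
                               shard_world_size):
    """Copy this rank's contiguous slice of the full norm weight into the
    local parameter buffer (assumes the weight divides evenly across shards)."""
    shard_size = loaded_weight.shape[0] // shard_world_size
    start = shard_rank * shard_size
    param[:] = loaded_weight[start:start + shard_size]
```

Parameterizing the shard world/rank in `__init__` (rather than hard-coding the TP group) is what lets the same loader also follow kv-head replication layouts.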
# 2. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.forward_qk`
# (or older `_normalize_qk`)
# Why:
# q/k norm for linear attention is performance-sensitive. On NPU, a fused
# rms_norm kernel is faster and TP needs a global rstd correction.
# How:
# Replace q/k normalization with NPU rms_norm fast path and TP-global rstd
# correction; fall back to upstream implementation on non-NPU.
# Related PR (if no, explain why):
# No, backend-specific optimization.
# Future Plan:
# Remove this patch when upstream adds a backend dispatch path for q/k norm.
#
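The reason a TP-global rstd correction is needed: when the normalized dimension is sharded across ranks, each rank only sees a partial sum of squares, so the rstd must be computed from the all-reduced global sum. A NumPy sketch with a list of shards emulating the ranks (the weight multiplication of a full RMSNorm is elided; names are illustrative):

```python
import numpy as np


def rmsnorm_tp(shards, eps=1e-6):
    """`shards` holds each rank's slice of one hidden vector; returns the
    per-rank normalized slices. Summing across shards emulates the
    all-reduce of per-rank sums of squares."""
    dim = sum(s.shape[0] for s in shards)
    # All-reduce: combine per-rank sums of squares into the global one.
    global_sq = sum(float(np.sum(s.astype(np.float64) ** 2)) for s in shards)
    rstd = 1.0 / np.sqrt(global_sq / dim + eps)
    # Each rank scales its own slice by the globally consistent rstd.
    return [s * rstd for s in shards]
```

Without the all-reduce step, each rank would normalize by its local statistics and the concatenated result would differ from the unsharded RMSNorm.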
# ** 19. File: worker/patch_qwen3_5.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.models.qwen3_5.Qwen3_5GatedDeltaNet._forward_core`
# Why: