[v0.18.0] Apply Eagle3 to MiniMax-M2.5 (#7619) (#7714)

### What this PR does / why we need it?
Apply Eagle3 to MiniMax-M2.5 to improve model performance. This will be
discarded once Eagle3 weights for MiniMax-M2.5 are released and the code
change is accepted by the official repo:
https://github.com/vllm-project/vllm/pull/37512/changes
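
For context, a rough usage sketch of what this enables (the draft checkpoint path below is a placeholder since official Eagle3 weights for MiniMax-M2.5 are not released yet, and exact engine arguments may vary by vLLM version):

```python
# Hypothetical sketch: Eagle3 speculative decoding for MiniMax-M2.
# Model paths and parallelism settings are placeholders, not tested values.
from vllm import LLM, SamplingParams

llm = LLM(
    model="MiniMaxAI/MiniMax-M2",  # target model (placeholder)
    tensor_parallel_size=8,
    speculative_config={
        "method": "eagle3",
        "model": "/path/to/eagle3-minimax-m2-draft",  # unreleased draft weights
        "num_speculative_tokens": 3,
    },
)
out = llm.generate(["Hello, MiniMax!"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```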
backport: #7619

- vLLM version: v0.18.0
- vLLM main:
ed359c497a

Signed-off-by: limuyuan <limuyuan3@huawei.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
SparrowMu
2026-03-27 18:33:29 +08:00
committed by GitHub
parent 60e88d9541
commit 6fbd0049df
3 changed files with 312 additions and 24 deletions

View File

@@ -137,6 +137,38 @@
# Remove this patch if upstream provides an official NPU graph-capture
# guidance / auto-configuration path for HCCL.
#
# 3. `vllm.config.speculative.SpeculativeConfig._verify_args`
# Why:
# Upstream vLLM's eagle3/extract_hidden_states restricts target model types
# via a whitelist. MiniMax-M2 should be allowed once the worker-side model
# can emit auxiliary hidden states.
# How:
# Monkey-patch `_verify_args` to bypass only the whitelist ValueError for
# MiniMax model_type when method is eagle3/extract_hidden_states.
# SpeculativeConfig is a Pydantic dataclass (`@config`); init validation calls
# `__pydantic_decorators__.model_validators["_verify_args"].func`, so
# `Decorator.func` must be replaced (not only `SpeculativeConfig._verify_args`),
# then `rebuild_dataclass(SpeculativeConfig, force=True)`.
# If `VllmConfig` was imported earlier, also `rebuild_dataclass(VllmConfig, ...)`
# so nested `speculative_config` validation does not use a stale schema.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream whitelist includes MiniMax.
#
# 4. `vllm.model_executor.models.registry` (spec decode aliases)
# Why:
# Some Eagle3 draft checkpoints may declare a MiniMax-specific architecture
# string while reusing the shared Eagle3 implementation.
# How:
# Register `Eagle3MiniMaxM2ForCausalLM` as an alias pointing to the
# existing Eagle3 implementation in the speculative decoding registry.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Drop the alias once upstream registry includes it or the checkpoint
# standardizes architecture strings.
#
# ** 8. File: platform/patch_kv_cache_interface.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
@@ -453,6 +485,31 @@
# Future Plan:
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
#
# 4. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.forward`
# Why:
# Eagle3 speculative decoding needs auxiliary hidden states from specific
# transformer layers of the target model.
# How:
# Extend `MiniMaxM2Model.forward` to optionally collect and return
# `(final_hidden_states, aux_hidden_states)` when `aux_hidden_state_layers`
# is set by the runtime.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream MiniMax-M2 integrates Eagle3 support.
#
# 5. `vllm.model_executor.models.minimax_m2.MiniMaxM2ForCausalLM`
# Why:
# vLLM core uses SupportsEagle3-style methods to configure which layers
# should emit auxiliary hidden states.
# How:
# Inject `set_aux_hidden_state_layers` and default-layer getters onto
# `MiniMaxM2ForCausalLM` so vLLM can configure the target model.
# Related PR (if no, explain why):
# https://github.com/vllm-project/vllm/pull/37512
# Future Plan:
# Remove this patch once upstream provides these methods on the model.
#
# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`

View File

@@ -134,3 +134,149 @@ if _original_verify_quantization is not None:
if _original_verify_cuda_graph is not None:
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
# ---------------------------------------------------------------------------
# Speculative decoding (Eagle3): allow MiniMax targets and registry alias.
# ---------------------------------------------------------------------------
def _patch_speculative_minimax_whitelist() -> None:
"""Allow MiniMax target models for eagle3/extract_hidden_states checks.
Upstream vLLM validates that the target model_type is in a whitelist for
methods that rely on auxiliary hidden states. Older upstream versions may
not include MiniMax yet.
"""
try:
from vllm.config.speculative import SpeculativeConfig # type: ignore
except Exception:
logger.warning(
"SpeculativeConfig is not found, skip patching eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
)
return
original_verify_args = getattr(SpeculativeConfig, "_verify_args", None)
if original_verify_args is None:
logger.warning(
"SpeculativeConfig._verify_args is not found, skip patching "
"eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
)
return
if getattr(original_verify_args, "_vllm_ascend_minimax_eagle3_patched", False):
logger.warning("eagle3/extract_hidden_states checks for MiniMax-M2 on NPU have already been patched.")
return
# Pydantic dataclass validation invokes `model_validators["_verify_args"].func`, not
# necessarily the current `SpeculativeConfig._verify_args` attribute.
decorators = getattr(SpeculativeConfig, "__pydantic_decorators__", None)
mv = None
if decorators is not None:
model_validators = getattr(decorators, "model_validators", None)
if isinstance(model_validators, dict):
mv = model_validators.get("_verify_args")
inner_verify = mv.func if mv is not None and getattr(mv, "func", None) is not None else original_verify_args
def _patched_verify_args(self, *args, **kwargs): # type: ignore[no-untyped-def]
try:
return inner_verify(self, *args, **kwargs)
except ValueError as e:
method = getattr(self, "method", None)
if method not in ("eagle3", "extract_hidden_states"):
raise
target_cfg = getattr(self, "target_model_config", None)
model_type = getattr(getattr(target_cfg, "hf_text_config", None), "model_type", "")
if "minimax" not in str(model_type).lower():
logger.warning(
"Model type %s is not a MiniMax-M2 model, skip eagle3/extract_hidden_states checks.",
model_type,
)
raise
msg = str(e).lower()
if "only supported for" in msg and "models" in msg:
# Upstream `_verify_args` calls `verify_equal_vocab_size_if_draft_model` after
# the aux-hidden allowlist; returning here would skip it.
verify_vocab = getattr(self, "verify_equal_vocab_size_if_draft_model", None)
if callable(verify_vocab):
verify_vocab()
return self
raise
_patched_verify_args._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
SpeculativeConfig._verify_args = _patched_verify_args # type: ignore[assignment]
if mv is not None:
try:
mv.func = _patched_verify_args # type: ignore[misc]
except (TypeError, AttributeError):
object.__setattr__(mv, "func", _patched_verify_args)
else:
logger.warning(
"Could not find SpeculativeConfig.__pydantic_decorators__.model_validators["
"'_verify_args']; eagle3 whitelist patch may not run at init validation."
)
try:
from pydantic.dataclasses import rebuild_dataclass # type: ignore
except Exception as e:
logger.warning(
"Cannot import rebuild_dataclass (%s); SpeculativeConfig eagle3 whitelist "
"patch may not apply at instance construction time.",
e,
)
else:
try:
rebuild_dataclass(SpeculativeConfig, force=True) # type: ignore[arg-type]
except Exception as e:
logger.warning(
"rebuild_dataclass(SpeculativeConfig) failed (%s); eagle3 whitelist patch may not apply.",
e,
)
# If `VllmConfig` was imported before this patch ran, its pydantic-core schema
# for the nested `speculative_config` field may still embed the *pre-patch*
# SpeculativeConfig validators. `create_speculative_config()` calls
# `SpeculativeConfig(...)` directly (uses updated class validator), but
# `VllmConfig(..., speculative_config=...)` validates via the parent's cached
# nested schema and can still raise the whitelist error unless we rebuild.
try:
from vllm.config.vllm import VllmConfig # type: ignore
except Exception:
pass
else:
try:
rebuild_dataclass(VllmConfig, force=True) # type: ignore[arg-type]
except Exception as e:
logger.warning(
"rebuild_dataclass(VllmConfig) failed (%s); VllmConfig(...) may "
"still use stale nested SpeculativeConfig validation.",
e,
)
def _patch_eagle3_registry_alias() -> None:
"""Register Eagle3MiniMaxM2ForCausalLM architecture alias if missing."""
try:
import vllm.model_executor.models.registry as registry # type: ignore
except Exception:
return
# Prefer patching the underlying dicts when available.
if hasattr(registry, "_SPECULATIVE_DECODING_MODELS"):
models = registry._SPECULATIVE_DECODING_MODELS
if isinstance(models, dict):
models.setdefault("Eagle3MiniMaxM2ForCausalLM", ("llama_eagle3", "Eagle3LlamaForCausalLM"))
# Fallback: patch resolved registry instance if present.
model_registry = getattr(registry, "ModelRegistry", None)
if model_registry is not None and hasattr(model_registry, "models"):
try:
model_registry.models.setdefault( # type: ignore[attr-defined]
"Eagle3MiniMaxM2ForCausalLM",
("llama_eagle3", "Eagle3LlamaForCausalLM"),
)
except Exception:
return
_patch_speculative_minimax_whitelist()
_patch_eagle3_registry_alias()
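
The `__pydantic_decorators__` and `rebuild_dataclass` handling above is the subtle part of this patch. A standalone toy sketch of the same pattern (a hypothetical `Config` dataclass and `_relaxed_verify` helper, nothing vLLM-specific, assuming pydantic v2) looks roughly like this:

```python
# Toy illustration of patching a pydantic-dataclass model validator.
# `Config` and `_relaxed_verify` are made up for this sketch.
from pydantic import model_validator
from pydantic.dataclasses import dataclass, rebuild_dataclass


@dataclass
class Config:
    name: str = "allowed"

    @model_validator(mode="after")
    def _verify(self):
        if self.name != "allowed":
            raise ValueError("name is not in the whitelist")
        return self


def _relaxed_verify(self):
    # Bypass the whitelist entirely in this toy example.
    return self


# Replacing only the class attribute is not enough: __init__ validation runs
# the function cached in __pydantic_decorators__, so patch that and rebuild
# the pydantic-core schema. A frozen Decorator may need object.__setattr__.
Config.__pydantic_decorators__.model_validators["_verify"].func = _relaxed_verify
rebuild_dataclass(Config, force=True)

Config(name="minimax")  # passes now; it raised ValueError before the patch

# A parent dataclass that embeds Config as a field caches its own nested
# schema and would need its own rebuild_dataclass call, which is why the
# patch above also rebuilds VllmConfig.
```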

View File

@@ -21,12 +21,19 @@ from collections.abc import Iterable
import torch
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.model_executor.models.minimax_m2 import (
MiniMaxM2Attention,
MiniMaxM2ForCausalLM,
MiniMaxM2Model,
MiniMaxM2MoE,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice
@@ -87,6 +94,34 @@ def _patched_attention_init(self, *args, **kwargs) -> None:
MiniMaxM2Attention.__init__ = _patched_attention_init
def _patch_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
cos, sin = get_cos_and_sin_slice()
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
input=qkv,
q_weight=self.q_norm.weight,
k_weight=self.k_norm.weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
eps=self.q_norm.variance_epsilon,
tp_world=self.q_norm.tp_world,
cos=cos,
sin=sin,
)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
MiniMaxM2Attention.forward = _patch_forward
# ---------------------------------------------------------------------------
# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
# ---------------------------------------------------------------------------
@@ -176,29 +211,79 @@ def _patched_load_weights(
MiniMaxM2Model.load_weights = _patched_load_weights
# ---------------------------------------------------------------------------
# MiniMaxM2Model / MiniMaxM2ForCausalLM: Eagle3 aux hidden states support
# ---------------------------------------------------------------------------
_original_minimax_m2_forward = MiniMaxM2Model.forward
def _patched_minimax_m2_forward(
self: "MiniMaxM2Model",
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
aux_layers: tuple[int, ...] = getattr(self, "aux_hidden_state_layers", ()) or ()
if not aux_layers:
return _original_minimax_m2_forward(self, input_ids, positions, intermediate_tensors, inputs_embeds)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states: list[torch.Tensor] = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
layer_idx = self.start_layer + idx
if layer_idx in aux_layers:
aux_hidden_states.append(hidden_states + residual if residual is not None else hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
hidden_states, _ = self.norm(hidden_states, residual)
if aux_hidden_states:
return hidden_states, aux_hidden_states
return hidden_states
if not getattr(_original_minimax_m2_forward, "_vllm_ascend_minimax_eagle3_patched", False):
MiniMaxM2Model.forward = _patched_minimax_m2_forward # type: ignore[assignment]
MiniMaxM2Model.forward._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
def _set_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM", layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = tuple(int(x) for x in layers)
def _get_eagle3_default_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_eagle3_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
return _get_eagle3_default_aux_hidden_state_layers(self)
# vLLM 0.18+: `supports_eagle3(model)` is `isinstance(model, SupportsEagle3)` (see
# `vllm.model_executor.models.interfaces`). `SupportsEagle3` extends `SupportsEagleBase`;
# runtime protocol checks require class attributes below (not only Eagle3 methods), or
# isinstance fails and model_runner_v1 raises:
# "Model does not support EAGLE3 interface but aux_hidden_state_outputs was requested".
MiniMaxM2ForCausalLM.has_own_lm_head = False # type: ignore[misc]
MiniMaxM2ForCausalLM.has_own_embed_tokens = False # type: ignore[misc]
MiniMaxM2ForCausalLM.supports_eagle3 = True # type: ignore[misc]
MiniMaxM2ForCausalLM.set_aux_hidden_state_layers = _set_aux_hidden_state_layers # type: ignore[attr-defined]
MiniMaxM2ForCausalLM.get_eagle3_default_aux_hidden_state_layers = ( # type: ignore[attr-defined]
_get_eagle3_default_aux_hidden_state_layers
)
MiniMaxM2ForCausalLM.get_eagle3_aux_hidden_state_layers = _get_eagle3_aux_hidden_state_layers # type: ignore[attr-defined]
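
A simplified sketch (not the actual vLLM model-runner code) of how the worker side is expected to consume the hooks injected above; `configure_eagle3_target` is a made-up helper name:

```python
# Sketch of worker-side wiring for the Eagle3 hooks patched above.
from vllm.model_executor.models.interfaces import supports_eagle3


def configure_eagle3_target(model) -> None:
    # With the class attributes set above, supports_eagle3() (an isinstance
    # check against SupportsEagle3) now accepts MiniMaxM2ForCausalLM.
    if not supports_eagle3(model):
        raise TypeError("Model does not support the EAGLE3 interface")
    # For a hypothetical 62-layer target this returns (2, 31, 59),
    # following the default (2, num_layers // 2, num_layers - 3) above.
    layers = model.get_eagle3_aux_hidden_state_layers()
    model.set_aux_hidden_state_layers(layers)


# Once configured, the patched MiniMaxM2Model.forward returns
# (final_hidden_states, aux_hidden_states) on the last pipeline rank, and the
# Eagle3 drafter consumes the auxiliary states to propose draft tokens.
```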