### What this PR does / why we need it?
Apply Eagle3 to MiniMax-M2.5 to increase model performance This will be
discard after Eagle3 weight for MiniMax-M2.5 releases and code change
accepted by official repo
https://github.com/vllm-project/vllm/pull/37512/changes
backport: #7619
- vLLM version: v0.18.0
- vLLM main:
ed359c497a
Signed-off-by: limuyuan <limuyuan3@huawei.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
This commit is contained in:
@@ -137,6 +137,38 @@
|
||||
# Remove this patch if upstream provides an official NPU graph-capture
|
||||
# guidance / auto-configuration path for HCCL.
|
||||
#
|
||||
# 3. `vllm.config.speculative.SpeculativeConfig._verify_args`
|
||||
# Why:
|
||||
# Upstream vLLM's eagle3/extract_hidden_states restricts target model types
|
||||
# via a whitelist. MiniMax-M2 should be allowed once the worker-side model
|
||||
# can emit auxiliary hidden states.
|
||||
# How:
|
||||
# Monkey-patch `_verify_args` to bypass only the whitelist ValueError for
|
||||
# MiniMax model_type when method is eagle3/extract_hidden_states.
|
||||
# SpeculativeConfig is a Pydantic dataclass (`@config`); init validation calls
|
||||
# `__pydantic_decorators__.model_validators["_verify_args"].func`, so that
|
||||
# `Decorator.func` must be replaced (not only `SpeculativeConfig._verify_args`),
|
||||
# then `rebuild_dataclass(SpeculativeConfig, force=True)`.
|
||||
# If `VllmConfig` was imported earlier, also `rebuild_dataclass(VllmConfig, ...)`
|
||||
# so nested `speculative_config` validation does not use a stale schema.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/37512
|
||||
# Future Plan:
|
||||
# Remove this patch once upstream whitelist includes MiniMax.
|
||||
#
|
||||
# 4. `vllm.model_executor.models.registry` (spec decode aliases)
|
||||
# Why:
|
||||
# Some Eagle3 draft checkpoints may declare a MiniMax-specific architecture
|
||||
# string while reusing the shared Eagle3 implementation.
|
||||
# How:
|
||||
# Register `Eagle3MiniMaxM2ForCausalLM` as an alias pointing to the
|
||||
# existing Eagle3 implementation in the speculative decoding registry.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/37512
|
||||
# Future Plan:
|
||||
# Drop the alias once upstream registry includes it or the checkpoint
|
||||
# standardizes architecture strings.
|
||||
#
|
||||
# ** 8. File: platform/patch_kv_cache_interface.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.v1.kv_cache_interface.MLAAttentionSpec`
|
||||
@@ -453,6 +485,31 @@
|
||||
# Future Plan:
|
||||
# Remove this patch when upstream supports MiniMax-M2 fp8 loading on NPU.
|
||||
#
|
||||
# 4. `vllm.model_executor.models.minimax_m2.MiniMaxM2Model.forward`
|
||||
# Why:
|
||||
# Eagle3 speculative decoding needs auxiliary hidden states from specific
|
||||
# transformer layers of the target model.
|
||||
# How:
|
||||
# Extend `MiniMaxM2Model.forward` to optionally collect and return
|
||||
# `(final_hidden_states, aux_hidden_states)` when `aux_hidden_state_layers`
|
||||
# is set by the runtime.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/37512
|
||||
# Future Plan:
|
||||
# Remove this patch once upstream MiniMax-M2 integrates Eagle3 support.
|
||||
#
|
||||
# 5. `vllm.model_executor.models.minimax_m2.MiniMaxM2ForCausalLM`
|
||||
# Why:
|
||||
# vLLM core uses SupportsEagle3-style methods to configure which layers
|
||||
# should emit auxiliary hidden states.
|
||||
# How:
|
||||
# Inject `set_aux_hidden_state_layers` and default-layer getters onto
|
||||
# `MiniMaxM2ForCausalLM` so vLLM can configure the target model.
|
||||
# Related PR (if no, explain why):
|
||||
# https://github.com/vllm-project/vllm/pull/37512
|
||||
# Future Plan:
|
||||
# Remove this patch once upstream provides these methods on the model.
|
||||
#
|
||||
# ** 18. File: worker/patch_minimax_m2_linear_attn.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.layers.mamba.linear_attn.MiniMaxText01RMSNormTP.__init__`
|
||||
|
||||
@@ -134,3 +134,149 @@ if _original_verify_quantization is not None:
|
||||
|
||||
if _original_verify_cuda_graph is not None:
|
||||
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Speculative decoding (Eagle3): allow MiniMax targets and registry alias.
|
||||
# ---------------------------------------------------------------------------
|
||||
def _patch_speculative_minimax_whitelist() -> None:
|
||||
"""Allow MiniMax target models for eagle3/extract_hidden_states checks.
|
||||
|
||||
Upstream vLLM validates that the target model_type is in a whitelist for
|
||||
methods that rely on auxiliary hidden states. Older upstream versions may
|
||||
not include MiniMax yet.
|
||||
"""
|
||||
try:
|
||||
from vllm.config.speculative import SpeculativeConfig # type: ignore
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"SpeculativeConfig is not found, skip patching eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
|
||||
)
|
||||
return
|
||||
|
||||
original_verify_args = getattr(SpeculativeConfig, "_verify_args", None)
|
||||
if original_verify_args is None:
|
||||
logger.warning(
|
||||
"SpeculativeConfig._verify_args is not found, skip patching "
|
||||
"eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
|
||||
)
|
||||
return
|
||||
if getattr(original_verify_args, "_vllm_ascend_minimax_eagle3_patched", False):
|
||||
logger.warning("eagle3/extract_hidden_states checks for MiniMax-M2 on NPU have already been patched.")
|
||||
return
|
||||
|
||||
# Pydantic dataclass validation invokes `model_validators["_verify_args"].func`, not
|
||||
# necessarily the current `SpeculativeConfig._verify_args` attribute.
|
||||
decorators = getattr(SpeculativeConfig, "__pydantic_decorators__", None)
|
||||
mv = None
|
||||
if decorators is not None:
|
||||
model_validators = getattr(decorators, "model_validators", None)
|
||||
if isinstance(model_validators, dict):
|
||||
mv = model_validators.get("_verify_args")
|
||||
inner_verify = mv.func if mv is not None and getattr(mv, "func", None) is not None else original_verify_args
|
||||
|
||||
def _patched_verify_args(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
return inner_verify(self, *args, **kwargs)
|
||||
except ValueError as e:
|
||||
method = getattr(self, "method", None)
|
||||
if method not in ("eagle3", "extract_hidden_states"):
|
||||
raise
|
||||
|
||||
target_cfg = getattr(self, "target_model_config", None)
|
||||
model_type = getattr(getattr(target_cfg, "hf_text_config", None), "model_type", "")
|
||||
if "minimax" not in str(model_type).lower():
|
||||
logger.warning(
|
||||
"Model type %s is not a MiniMax-M2 model, skip eagle3/extract_hidden_states checks.",
|
||||
model_type,
|
||||
)
|
||||
raise
|
||||
|
||||
msg = str(e).lower()
|
||||
if "only supported for" in msg and "models" in msg:
|
||||
# Upstream `_verify_args` calls `verify_equal_vocab_size_if_draft_model` after
|
||||
# the aux-hidden allowlist; returning here would skip it.
|
||||
verify_vocab = getattr(self, "verify_equal_vocab_size_if_draft_model", None)
|
||||
if callable(verify_vocab):
|
||||
verify_vocab()
|
||||
return self
|
||||
raise
|
||||
|
||||
_patched_verify_args._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
|
||||
SpeculativeConfig._verify_args = _patched_verify_args # type: ignore[assignment]
|
||||
|
||||
if mv is not None:
|
||||
try:
|
||||
mv.func = _patched_verify_args # type: ignore[misc]
|
||||
except (TypeError, AttributeError):
|
||||
object.__setattr__(mv, "func", _patched_verify_args)
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not find SpeculativeConfig.__pydantic_decorators__.model_validators["
|
||||
"'_verify_args']; eagle3 whitelist patch may not run at init validation."
|
||||
)
|
||||
|
||||
try:
|
||||
from pydantic.dataclasses import rebuild_dataclass # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot import rebuild_dataclass (%s); SpeculativeConfig eagle3 whitelist "
|
||||
"patch may not apply at instance construction time.",
|
||||
e,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
rebuild_dataclass(SpeculativeConfig, force=True) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"rebuild_dataclass(SpeculativeConfig) failed (%s); eagle3 whitelist patch may not apply.",
|
||||
e,
|
||||
)
|
||||
# If `VllmConfig` was imported before this patch ran, its pydantic-core schema
|
||||
# for the nested `speculative_config` field may still embed the *pre-patch*
|
||||
# SpeculativeConfig validators. `create_speculative_config()` calls
|
||||
# `SpeculativeConfig(...)` directly (uses updated class validator), but
|
||||
# `VllmConfig(..., speculative_config=...)` validates via the parent's cached
|
||||
# nested schema and can still raise the whitelist error unless we rebuild.
|
||||
try:
|
||||
from vllm.config.vllm import VllmConfig # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
rebuild_dataclass(VllmConfig, force=True) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"rebuild_dataclass(VllmConfig) failed (%s); VllmConfig(...) may "
|
||||
"still use stale nested SpeculativeConfig validation.",
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
def _patch_eagle3_registry_alias() -> None:
|
||||
"""Register Eagle3MiniMaxM2ForCausalLM architecture alias if missing."""
|
||||
try:
|
||||
import vllm.model_executor.models.registry as registry # type: ignore
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Prefer patching the underlying dicts when available.
|
||||
if hasattr(registry, "_SPECULATIVE_DECODING_MODELS"):
|
||||
models = registry._SPECULATIVE_DECODING_MODELS
|
||||
if isinstance(models, dict):
|
||||
models.setdefault("Eagle3MiniMaxM2ForCausalLM", ("llama_eagle3", "Eagle3LlamaForCausalLM"))
|
||||
|
||||
# Fallback: patch resolved registry instance if present.
|
||||
model_registry = getattr(registry, "ModelRegistry", None)
|
||||
if model_registry is not None and hasattr(model_registry, "models"):
|
||||
try:
|
||||
model_registry.models.setdefault( # type: ignore[attr-defined]
|
||||
"Eagle3MiniMaxM2ForCausalLM",
|
||||
("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
_patch_speculative_minimax_whitelist()
|
||||
_patch_eagle3_registry_alias()
|
||||
|
||||
@@ -21,12 +21,19 @@ from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from vllm.distributed import (
|
||||
get_pp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
|
||||
from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
|
||||
from vllm.model_executor.models.minimax_m2 import (
|
||||
MiniMaxM2Attention,
|
||||
MiniMaxM2ForCausalLM,
|
||||
MiniMaxM2Model,
|
||||
MiniMaxM2MoE,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice
|
||||
|
||||
@@ -87,6 +94,34 @@ def _patched_attention_init(self, *args, **kwargs) -> None:
|
||||
MiniMaxM2Attention.__init__ = _patched_attention_init
|
||||
|
||||
|
||||
def _patch_forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=self.q_norm.weight,
|
||||
k_weight=self.k_norm.weight,
|
||||
q_hidden_size=self.q_size,
|
||||
kv_hidden_size=self.kv_size,
|
||||
head_dim=self.head_dim,
|
||||
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
|
||||
eps=self.q_norm.variance_epsilon,
|
||||
tp_world=self.q_norm.tp_world,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
MiniMaxM2Attention.forward = _patch_forward
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -176,29 +211,79 @@ def _patched_load_weights(
|
||||
MiniMaxM2Model.load_weights = _patched_load_weights
|
||||
|
||||
|
||||
def _patch_forward(
|
||||
self,
|
||||
# ---------------------------------------------------------------------------
|
||||
# MiniMaxM2Model / MiniMaxM2ForCausalLM: Eagle3 aux hidden states support
|
||||
# ---------------------------------------------------------------------------
|
||||
_original_minimax_m2_forward = MiniMaxM2Model.forward
|
||||
|
||||
|
||||
def _patched_minimax_m2_forward(
|
||||
self: "MiniMaxM2Model",
|
||||
input_ids: torch.Tensor | None,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
cos, sin = get_cos_and_sin_slice()
|
||||
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
|
||||
input=qkv,
|
||||
q_weight=self.q_norm.weight,
|
||||
k_weight=self.k_norm.weight,
|
||||
q_hidden_size=self.q_size,
|
||||
kv_hidden_size=self.kv_size,
|
||||
head_dim=self.head_dim,
|
||||
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
|
||||
eps=self.q_norm.variance_epsilon,
|
||||
tp_world=self.q_norm.tp_world,
|
||||
cos=cos,
|
||||
sin=sin,
|
||||
)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
aux_layers: tuple[int, ...] = getattr(self, "aux_hidden_state_layers", ()) or ()
|
||||
if not aux_layers:
|
||||
return _original_minimax_m2_forward(self, input_ids, positions, intermediate_tensors, inputs_embeds)
|
||||
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_input_ids(input_ids)
|
||||
residual = None
|
||||
else:
|
||||
assert intermediate_tensors is not None
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
aux_hidden_states: list[torch.Tensor] = []
|
||||
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
|
||||
layer_idx = self.start_layer + idx
|
||||
if layer_idx in aux_layers:
|
||||
aux_hidden_states.append(hidden_states + residual if residual is not None else hidden_states)
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if aux_hidden_states:
|
||||
return hidden_states, aux_hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
MiniMaxM2Attention.forward = _patch_forward
|
||||
if not getattr(_original_minimax_m2_forward, "_vllm_ascend_minimax_eagle3_patched", False):
|
||||
MiniMaxM2Model.forward = _patched_minimax_m2_forward # type: ignore[assignment]
|
||||
MiniMaxM2Model.forward._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def _set_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM", layers: tuple[int, ...]) -> None:
|
||||
self.model.aux_hidden_state_layers = tuple(int(x) for x in layers)
|
||||
|
||||
|
||||
def _get_eagle3_default_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
|
||||
num_layers = len(self.model.layers)
|
||||
return (2, num_layers // 2, num_layers - 3)
|
||||
|
||||
|
||||
def _get_eagle3_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
|
||||
return _get_eagle3_default_aux_hidden_state_layers(self)
|
||||
|
||||
|
||||
# vLLM 0.18+: `supports_eagle3(model)` is `isinstance(model, SupportsEagle3)` (see
|
||||
# `vllm.model_executor.models.interfaces`). `SupportsEagle3` extends `SupportsEagleBase`;
|
||||
# runtime protocol checks require class attributes below (not only Eagle3 methods), or
|
||||
# isinstance fails and model_runner_v1 raises:
|
||||
# "Model does not support EAGLE3 interface but aux_hidden_state_outputs was requested".
|
||||
MiniMaxM2ForCausalLM.has_own_lm_head = False # type: ignore[misc]
|
||||
MiniMaxM2ForCausalLM.has_own_embed_tokens = False # type: ignore[misc]
|
||||
MiniMaxM2ForCausalLM.supports_eagle3 = True # type: ignore[misc]
|
||||
|
||||
MiniMaxM2ForCausalLM.set_aux_hidden_state_layers = _set_aux_hidden_state_layers # type: ignore[attr-defined]
|
||||
MiniMaxM2ForCausalLM.get_eagle3_default_aux_hidden_state_layers = ( # type: ignore[attr-defined]
|
||||
_get_eagle3_default_aux_hidden_state_layers
|
||||
)
|
||||
MiniMaxM2ForCausalLM.get_eagle3_aux_hidden_state_layers = _get_eagle3_aux_hidden_state_layers # type: ignore[attr-defined]
|
||||
|
||||
Reference in New Issue
Block a user