[v0.18.0] Apply Eagle3 to MiniMax-M2.5 (#7619) (#7714)

### What this PR does / why we need it?
Apply Eagle3 to MiniMax-M2.5 to increase model performance This will be
discard after Eagle3 weight for MiniMax-M2.5 releases and code change
accepted by official repo
https://github.com/vllm-project/vllm/pull/37512/changes
backport: #7619

- vLLM version: v0.18.0
- vLLM main:
ed359c497a

Signed-off-by: limuyuan <limuyuan3@huawei.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
This commit is contained in:
SparrowMu
2026-03-27 18:33:29 +08:00
committed by GitHub
parent 60e88d9541
commit 6fbd0049df
3 changed files with 312 additions and 24 deletions

View File

@@ -21,12 +21,19 @@ from collections.abc import Iterable
import torch
from vllm.distributed import (
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.mamba.linear_attn import MiniMaxText01RMSNormTP
from vllm.model_executor.models.minimax_m2 import MiniMaxM2Attention, MiniMaxM2Model, MiniMaxM2MoE
from vllm.model_executor.models.minimax_m2 import (
MiniMaxM2Attention,
MiniMaxM2ForCausalLM,
MiniMaxM2Model,
MiniMaxM2MoE,
)
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_slice
@@ -87,6 +94,34 @@ def _patched_attention_init(self, *args, **kwargs) -> None:
MiniMaxM2Attention.__init__ = _patched_attention_init
def _patch_forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
cos, sin = get_cos_and_sin_slice()
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
input=qkv,
q_weight=self.q_norm.weight,
k_weight=self.k_norm.weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
eps=self.q_norm.variance_epsilon,
tp_world=self.q_norm.tp_world,
cos=cos,
sin=sin,
)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
MiniMaxM2Attention.forward = _patch_forward
# ---------------------------------------------------------------------------
# MiniMaxM2Model: fp8 dequant helpers and load_weights wrapper
# ---------------------------------------------------------------------------
@@ -176,29 +211,79 @@ def _patched_load_weights(
MiniMaxM2Model.load_weights = _patched_load_weights
def _patch_forward(
self,
# ---------------------------------------------------------------------------
# MiniMaxM2Model / MiniMaxM2ForCausalLM: Eagle3 aux hidden states support
# ---------------------------------------------------------------------------
_original_minimax_m2_forward = MiniMaxM2Model.forward
def _patched_minimax_m2_forward(
self: "MiniMaxM2Model",
input_ids: torch.Tensor | None,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
cos, sin = get_cos_and_sin_slice()
q, k, v = torch.ops.vllm.split_qkv_tp_rmsnorm_rope(
input=qkv,
q_weight=self.q_norm.weight,
k_weight=self.k_norm.weight,
q_hidden_size=self.q_size,
kv_hidden_size=self.kv_size,
head_dim=self.head_dim,
rotary_dim=getattr(self.rotary_emb, "rotary_dim", self.head_dim),
eps=self.q_norm.variance_epsilon,
tp_world=self.q_norm.tp_world,
cos=cos,
sin=sin,
)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
aux_layers: tuple[int, ...] = getattr(self, "aux_hidden_state_layers", ()) or ()
if not aux_layers:
return _original_minimax_m2_forward(self, input_ids, positions, intermediate_tensors, inputs_embeds)
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
aux_hidden_states: list[torch.Tensor] = []
for idx, layer in enumerate(self.layers[self.start_layer : self.end_layer]):
layer_idx = self.start_layer + idx
if layer_idx in aux_layers:
aux_hidden_states.append(hidden_states + residual if residual is not None else hidden_states)
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states, "residual": residual})
hidden_states, _ = self.norm(hidden_states, residual)
if aux_hidden_states:
return hidden_states, aux_hidden_states
return hidden_states
MiniMaxM2Attention.forward = _patch_forward
if not getattr(_original_minimax_m2_forward, "_vllm_ascend_minimax_eagle3_patched", False):
MiniMaxM2Model.forward = _patched_minimax_m2_forward # type: ignore[assignment]
MiniMaxM2Model.forward._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
def _set_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM", layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = tuple(int(x) for x in layers)
def _get_eagle3_default_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def _get_eagle3_aux_hidden_state_layers(self: "MiniMaxM2ForCausalLM") -> tuple[int, ...]:
return _get_eagle3_default_aux_hidden_state_layers(self)
# vLLM 0.18+: `supports_eagle3(model)` is `isinstance(model, SupportsEagle3)` (see
# `vllm.model_executor.models.interfaces`). `SupportsEagle3` extends `SupportsEagleBase`;
# runtime protocol checks require class attributes below (not only Eagle3 methods), or
# isinstance fails and model_runner_v1 raises:
# "Model does not support EAGLE3 interface but aux_hidden_state_outputs was requested".
MiniMaxM2ForCausalLM.has_own_lm_head = False # type: ignore[misc]
MiniMaxM2ForCausalLM.has_own_embed_tokens = False # type: ignore[misc]
MiniMaxM2ForCausalLM.supports_eagle3 = True # type: ignore[misc]
MiniMaxM2ForCausalLM.set_aux_hidden_state_layers = _set_aux_hidden_state_layers # type: ignore[attr-defined]
MiniMaxM2ForCausalLM.get_eagle3_default_aux_hidden_state_layers = ( # type: ignore[attr-defined]
_get_eagle3_default_aux_hidden_state_layers
)
MiniMaxM2ForCausalLM.get_eagle3_aux_hidden_state_layers = _get_eagle3_aux_hidden_state_layers # type: ignore[attr-defined]