### What this PR does / why we need it?
Apply Eagle3 to MiniMax-M2.5 to increase model performance This will be
discard after Eagle3 weight for MiniMax-M2.5 releases and code change
accepted by official repo
https://github.com/vllm-project/vllm/pull/37512/changes
backport: #7619
- vLLM version: v0.18.0
- vLLM main:
ed359c497a
Signed-off-by: limuyuan <limuyuan3@huawei.com>
Co-authored-by: limuyuan <limuyuan3@huawei.com>
This commit is contained in:
@@ -134,3 +134,149 @@ if _original_verify_quantization is not None:
|
||||
|
||||
if _original_verify_cuda_graph is not None:
|
||||
ModelConfig._verify_cuda_graph = _patched_verify_cuda_graph
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Speculative decoding (Eagle3): allow MiniMax targets and registry alias.
|
||||
# ---------------------------------------------------------------------------
|
||||
def _patch_speculative_minimax_whitelist() -> None:
|
||||
"""Allow MiniMax target models for eagle3/extract_hidden_states checks.
|
||||
|
||||
Upstream vLLM validates that the target model_type is in a whitelist for
|
||||
methods that rely on auxiliary hidden states. Older upstream versions may
|
||||
not include MiniMax yet.
|
||||
"""
|
||||
try:
|
||||
from vllm.config.speculative import SpeculativeConfig # type: ignore
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"SpeculativeConfig is not found, skip patching eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
|
||||
)
|
||||
return
|
||||
|
||||
original_verify_args = getattr(SpeculativeConfig, "_verify_args", None)
|
||||
if original_verify_args is None:
|
||||
logger.warning(
|
||||
"SpeculativeConfig._verify_args is not found, skip patching "
|
||||
"eagle3/extract_hidden_states checks for MiniMax-M2 on NPU."
|
||||
)
|
||||
return
|
||||
if getattr(original_verify_args, "_vllm_ascend_minimax_eagle3_patched", False):
|
||||
logger.warning("eagle3/extract_hidden_states checks for MiniMax-M2 on NPU have already been patched.")
|
||||
return
|
||||
|
||||
# Pydantic dataclass validation invokes `model_validators["_verify_args"].func`, not
|
||||
# necessarily the current `SpeculativeConfig._verify_args` attribute.
|
||||
decorators = getattr(SpeculativeConfig, "__pydantic_decorators__", None)
|
||||
mv = None
|
||||
if decorators is not None:
|
||||
model_validators = getattr(decorators, "model_validators", None)
|
||||
if isinstance(model_validators, dict):
|
||||
mv = model_validators.get("_verify_args")
|
||||
inner_verify = mv.func if mv is not None and getattr(mv, "func", None) is not None else original_verify_args
|
||||
|
||||
def _patched_verify_args(self, *args, **kwargs): # type: ignore[no-untyped-def]
|
||||
try:
|
||||
return inner_verify(self, *args, **kwargs)
|
||||
except ValueError as e:
|
||||
method = getattr(self, "method", None)
|
||||
if method not in ("eagle3", "extract_hidden_states"):
|
||||
raise
|
||||
|
||||
target_cfg = getattr(self, "target_model_config", None)
|
||||
model_type = getattr(getattr(target_cfg, "hf_text_config", None), "model_type", "")
|
||||
if "minimax" not in str(model_type).lower():
|
||||
logger.warning(
|
||||
"Model type %s is not a MiniMax-M2 model, skip eagle3/extract_hidden_states checks.",
|
||||
model_type,
|
||||
)
|
||||
raise
|
||||
|
||||
msg = str(e).lower()
|
||||
if "only supported for" in msg and "models" in msg:
|
||||
# Upstream `_verify_args` calls `verify_equal_vocab_size_if_draft_model` after
|
||||
# the aux-hidden allowlist; returning here would skip it.
|
||||
verify_vocab = getattr(self, "verify_equal_vocab_size_if_draft_model", None)
|
||||
if callable(verify_vocab):
|
||||
verify_vocab()
|
||||
return self
|
||||
raise
|
||||
|
||||
_patched_verify_args._vllm_ascend_minimax_eagle3_patched = True # type: ignore[attr-defined]
|
||||
SpeculativeConfig._verify_args = _patched_verify_args # type: ignore[assignment]
|
||||
|
||||
if mv is not None:
|
||||
try:
|
||||
mv.func = _patched_verify_args # type: ignore[misc]
|
||||
except (TypeError, AttributeError):
|
||||
object.__setattr__(mv, "func", _patched_verify_args)
|
||||
else:
|
||||
logger.warning(
|
||||
"Could not find SpeculativeConfig.__pydantic_decorators__.model_validators["
|
||||
"'_verify_args']; eagle3 whitelist patch may not run at init validation."
|
||||
)
|
||||
|
||||
try:
|
||||
from pydantic.dataclasses import rebuild_dataclass # type: ignore
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot import rebuild_dataclass (%s); SpeculativeConfig eagle3 whitelist "
|
||||
"patch may not apply at instance construction time.",
|
||||
e,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
rebuild_dataclass(SpeculativeConfig, force=True) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"rebuild_dataclass(SpeculativeConfig) failed (%s); eagle3 whitelist patch may not apply.",
|
||||
e,
|
||||
)
|
||||
# If `VllmConfig` was imported before this patch ran, its pydantic-core schema
|
||||
# for the nested `speculative_config` field may still embed the *pre-patch*
|
||||
# SpeculativeConfig validators. `create_speculative_config()` calls
|
||||
# `SpeculativeConfig(...)` directly (uses updated class validator), but
|
||||
# `VllmConfig(..., speculative_config=...)` validates via the parent's cached
|
||||
# nested schema and can still raise the whitelist error unless we rebuild.
|
||||
try:
|
||||
from vllm.config.vllm import VllmConfig # type: ignore
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
rebuild_dataclass(VllmConfig, force=True) # type: ignore[arg-type]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"rebuild_dataclass(VllmConfig) failed (%s); VllmConfig(...) may "
|
||||
"still use stale nested SpeculativeConfig validation.",
|
||||
e,
|
||||
)
|
||||
|
||||
|
||||
def _patch_eagle3_registry_alias() -> None:
|
||||
"""Register Eagle3MiniMaxM2ForCausalLM architecture alias if missing."""
|
||||
try:
|
||||
import vllm.model_executor.models.registry as registry # type: ignore
|
||||
except Exception:
|
||||
return
|
||||
|
||||
# Prefer patching the underlying dicts when available.
|
||||
if hasattr(registry, "_SPECULATIVE_DECODING_MODELS"):
|
||||
models = registry._SPECULATIVE_DECODING_MODELS
|
||||
if isinstance(models, dict):
|
||||
models.setdefault("Eagle3MiniMaxM2ForCausalLM", ("llama_eagle3", "Eagle3LlamaForCausalLM"))
|
||||
|
||||
# Fallback: patch resolved registry instance if present.
|
||||
model_registry = getattr(registry, "ModelRegistry", None)
|
||||
if model_registry is not None and hasattr(model_registry, "models"):
|
||||
try:
|
||||
model_registry.models.setdefault( # type: ignore[attr-defined]
|
||||
"Eagle3MiniMaxM2ForCausalLM",
|
||||
("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
)
|
||||
except Exception:
|
||||
return
|
||||
|
||||
|
||||
_patch_speculative_minimax_whitelist()
|
||||
_patch_eagle3_registry_alias()
|
||||
|
||||
Reference in New Issue
Block a user