xc-llm-ascend/vllm_ascend/patch/worker/patch_draft_quarot.py
drslark 6a7115fa0d [main][feature] Support quarot for eagle3 without embedding (#7038)
### What this PR does / why we need it?
When an `eagle3` draft model that ships without `embed_tokens` runs against a
`quarot` target model, the acceptance rate drops. This PR fixes that by
anti-rotating the drafter's `fc` weight and the embedding it borrows from the
target.
The related vLLM PR is https://github.com/vllm-project/vllm/pull/36225.
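
For intuition, here is a minimal sketch of the anti-rotation this patch applies (not the PR's code; `Q` is assumed orthogonal, which is how quarot's global rotations are constructed, and all sizes are illustrative):

```python
import torch

h = 8  # illustrative hidden size
Q, _ = torch.linalg.qr(torch.randn(h, h, dtype=torch.float64))  # orthogonal Q

# A quarot checkpoint stores the rotated embedding E @ Q; multiplying by Q.T
# recovers the un-rotated table the drafter expects.
E = torch.randn(100, h, dtype=torch.float64)
assert torch.allclose((E @ Q) @ Q.T, E)

# eagle3's fc consumes 3 concatenated hidden states, each rotated by Q, so
# anti-rotating its weight with block_diag(Q, Q, Q) cancels the rotation.
Q3 = torch.block_diag(Q, Q, Q)
W = torch.randn(h, 3 * h, dtype=torch.float64)  # (out, in) linear weight
x = torch.randn(1, 3 * h, dtype=torch.float64)
assert torch.allclose((x @ Q3) @ (W @ Q3).T, x @ W.T)
```

This is why the patch multiplies the borrowed target embedding by `Q.T` and the drafter's `fc` weight by `block_diag(Q, Q, Q)`.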

- vLLM main:
4034c3d32e

Signed-off-by: drslark <slarksblood@qq.com>
2026-03-09 10:43:06 +08:00


import logging
import os
from collections.abc import Iterable
from pathlib import Path

import torch
from safetensors.torch import load_file
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    process_eagle_weight,
)

logger = logging.getLogger(__name__)

def get_embedding_tensor(directory_path):
    """
    Scans the directory and returns the first tensor found whose key contains
    'embed'. Returns the tensor if found, otherwise None.
    """
    if not os.path.isdir(directory_path):
        return None
    # List files and filter for .safetensors; os.listdir order is arbitrary,
    # so "first" means the first match across an unspecified shard order.
    for filename in os.listdir(directory_path):
        if filename.endswith(".safetensors"):
            file_path = os.path.join(directory_path, filename)
            # Load the file
            state_dict = load_file(file_path)
            # Search for the first matching key
            for key, tensor in state_dict.items():
                if "embed" in key.lower():
                    # Return immediately once found
                    return tensor
    return None

def get_rotation_path(target_vllm_config):
    """
    Gets the path of the rotation matrix; returns None if the target model is
    not a quarot model.
    """
    target_model_path = target_vllm_config.model_config.model
    try:
        quant_description = target_vllm_config.quant_config.quant_description
        rotation_relative_path = quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
    except (AttributeError, KeyError):
        # quant_config may be None for an unquantized target; treat that the
        # same as a missing quarot entry.
        return None
    return Path(target_model_path) / rotation_relative_path
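
# For reference, get_rotation_path expects the checkpoint's quant description
# to carry a nesting like the following (the file name is an illustrative
# example, not a fixed convention):
#
#   "optional": {"quarot": {"rotation_map": {"global_rotation": "rotation.safetensors"}}}
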
def get_rotation_matrix(rotation_path):
    """
    Loads the global rotation matrix used to anti-rotate weights.
    """
    try:
        safetensor_data = load_file(rotation_path)
        Q = safetensor_data["global_rotation"]
        return Q
    except Exception:
        logger.error(
            f"Failed to load rotation weight from '{rotation_path}'. "
            "If you want to use a quarot model with eagle3, please check this file."
        )
        raise

def compute_rotation_matrix3(Q):
    """
    Anti-rotation matrix for the 3 concatenated layers of hidden_states.
    """
    return torch.block_diag(Q, Q, Q)
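
# Example of the resulting shape (illustrative): if Q is (h, h), then
# torch.block_diag(Q, Q, Q) is (3h, 3h), one block per aux hidden state
# that eagle3's fc layer consumes after concatenation.
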
def patch_load_weights(target_vllm_config):
    target_model_path = Path(target_vllm_config.model_config.model)
    rotation_path = get_rotation_path(target_vllm_config)
    # If no rotation path is found, quarot is not in use and the stock
    # loader can stay.
    if rotation_path is None:
        return
    Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_model_path, rotation_path)
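
# Note: assigning to Eagle3LlamaForCausalLM.load_weights monkey-patches the
# class itself, so every instance created in this process picks up the
# quarot-aware loader from this point on.
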
def make_load_weights(target_model_path, rotation_path):

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        Q = get_rotation_matrix(rotation_path)
        Q3 = compute_rotation_matrix3(Q)
        if isinstance(self.config.dtype, str):
            embed_dtype = getattr(torch, self.config.dtype)
        else:
            embed_dtype = self.config.dtype
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "fc." in name:
                # Anti-rotate the fc weight so it accepts the target's
                # rotated (concatenated) hidden states.
                dtype = loaded_weight.dtype
                loaded_weight = loaded_weight @ Q3.to(dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)
        # If the drafter ships no embedding, borrow the target's and
        # anti-rotate it back to the un-rotated space.
        if not includes_embed_tokens:
            name = "model.embed_tokens.weight"
            loaded_weight = get_embedding_tensor(target_model_path).to(embed_dtype) @ Q.T.to(embed_dtype)
            model_weights[name] = loaded_weight
            includes_embed_tokens = True
            process_eagle_weight(self, name)
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights
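
# Hypothetical wiring (an assumption, not this file's code; the real hook
# lives in vllm_ascend's patch machinery): once the target model's vllm
# config is available, applying the patch is a single call,
#
#     patch_load_weights(target_vllm_config)
#
# after which Eagle3LlamaForCausalLM.load_weights performs the anti-rotation
# whenever the target checkpoint carries a quarot global rotation.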