xc-llm-ascend/vllm_ascend/patch/worker/patch_qwen3_quarot.py
zhaomingyu13 52d9086f64 [Bugfix] Fix the acceptance rate drop issue when applying Eagle3 to QuaRot models (#6914)
### What this PR does / why we need it?
When using a target model after rotational quantization (QuaRot), the
acceptance rate decreases because the fc weight of the draft model has
not undergone the same rotation (issue: #6445). We fix this by rotating
the fc weight of the draft model in the same way as the target model
when the draft model is loaded.
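
A minimal numerical sketch of why this works (shapes and names here are illustrative, not the patch's API): assuming, as the `torch.block_diag` construction in the patch suggests, that each of the three concatenated auxiliary hidden states fed to the fc layer is rotated by the same global orthogonal matrix `Q`, right-multiplying the fc weight by `block_diag(Q, Q, Q)` cancels the rotation and leaves the fc output unchanged.

```python
import torch

d = 4
# Random orthogonal Q via QR decomposition (stands in for the QuaRot rotation).
Q, _ = torch.linalg.qr(torch.randn(d, d))
h = torch.randn(2, 3 * d)        # three concatenated aux hidden states
W = torch.randn(8, 3 * d)        # fc weight in [out, in] layout: out = h @ W.T
R3 = torch.block_diag(Q, Q, Q)

out_plain = h @ W.T              # un-rotated target model
out_rot = (h @ R3) @ (W @ R3).T  # rotated hidden states, rotated fc weight
# R3 is orthogonal, so R3 @ R3.T = I and the outputs coincide.
assert torch.allclose(out_plain, out_rot, atol=1e-5)
```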

- vLLM version: v0.16.0
- vLLM main: 15d76f74e2

Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
2026-03-04 11:29:49 +08:00


import logging
from collections.abc import Iterable
from pathlib import Path

import torch
from safetensors.torch import load_file
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    process_eagle_weight,
)


def patch_load_weights(target_config):
    # Replace Eagle3's weight loader with one that rotates the draft fc
    # weight using the target model's QuaRot rotation.
    Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_config)


def make_load_weights(target_config):
    logger = logging.getLogger(__name__)
    quant_cfg = target_config.quant_config
    rotation_matrix3 = None
    model_path = target_config.model_config.model
    try:
        rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"][
            "rotation_map"]["global_rotation"]
    except KeyError as e:
        logger.error(
            "Invalid quant_config: missing key "
            "quant_description['optional']['quarot']['rotation_map']"
            "['global_rotation']. "
            "If you are not using a QuaRot model, please ignore this. "
            f"Error: {e}"
        )
    else:
        rotation_path = Path(model_path) / rotation_rel_path
        try:
            safetensor_data = load_file(rotation_path)
            Q = safetensor_data["global_rotation"]
            # The fc layer consumes three concatenated auxiliary hidden
            # states, so the global rotation is applied block-diagonally.
            rotation_matrix3 = torch.block_diag(Q, Q, Q)
        except Exception as e:
            logger.error(
                f"Failed to load rotation weight from '{rotation_path}'. "
                "If you are not using a QuaRot model, please ignore this. "
                f"Error: {e}"
            )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            # Rotate the draft fc weight with the same global rotation the
            # target model was quantized with, so it matches the rotated
            # auxiliary hidden states it receives.
            if "fc." in name and rotation_matrix3 is not None:
                loaded_weight = loaded_weight @ rotation_matrix3.to(
                    loaded_weight.dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights
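
For context, the patch uses closure-based monkey patching: `make_load_weights` computes the rotation matrix once from the target model's config, and the returned `load_weights` closure captures it and is installed on `Eagle3LlamaForCausalLM`. A stripped-down sketch of that pattern (the class and names below are stand-ins, not vLLM types):

```python
# Minimal illustration of the closure-based monkey-patch pattern above.
class DraftModel:
    def load_weights(self, weights):
        return dict(weights)  # original behavior: load weights as-is


def make_patched_loader(scale):
    # 'scale' plays the role of rotation_matrix3: computed once, captured
    # by the closure, and applied to every matching weight at load time.
    def load_weights(self, weights):
        return {name: (w * scale if "fc." in name else w)
                for name, w in weights}
    return load_weights


DraftModel.load_weights = make_patched_loader(2.0)  # patch applied once
m = DraftModel()
print(m.load_weights([("fc.weight", 3.0), ("other", 1.0)]))
# {'fc.weight': 6.0, 'other': 1.0}
```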