[Bugfix] Fix the acceptance rate drop issue when applying eagle3 to QuaRot model (#6914)
### What this PR does / why we need it?
When the target model has undergone rotational quantization (QuaRot), the
acceptance rate decreases because the fc weight of the draft model has
not been rotated in the same way (issue: #6445). We fix this issue
by applying the rotation to the fc weight of the draft model,
in the same way as for the main model, when loading the draft model.
- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
Signed-off-by: zhaomingyu <zhaomingyu13@h-partners.com>
This commit is contained in:
@@ -305,3 +305,14 @@
|
||||
# https://github.com/vllm-project/vllm/pull/34336
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM merges the PR.
|
||||
# ** 16. File: worker/patch_qwen3_quarot.py**
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights`
|
||||
# Why:
|
||||
# vllm-ascend reused the loading logic of drafter model from vllm,
|
||||
# but vllm's loading logic does not take Ascend (QuaRot) quantization into account.
|
||||
# How:
|
||||
# Dynamically replace the `load_weights` function at runtime,
|
||||
# and fix `target_config` into the new implementation with a closure.
|
||||
# Future Plan:
|
||||
# Remove this patch when vLLM merges the PR.
|
||||
|
||||
@@ -35,3 +35,4 @@ import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
|
||||
import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa
|
||||
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
|
||||
import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
|
||||
import vllm_ascend.patch.worker.patch_qwen3_quarot # noqa
|
||||
|
||||
79
vllm_ascend/patch/worker/patch_qwen3_quarot.py
Normal file
79
vllm_ascend/patch/worker/patch_qwen3_quarot.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from safetensors.torch import load_file
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.model_executor.models.utils import (
|
||||
AutoWeightsLoader,
|
||||
process_eagle_weight,
|
||||
)
|
||||
|
||||
|
||||
def patch_load_weights(target_config):
    """Replace ``Eagle3LlamaForCausalLM.load_weights`` at runtime.

    The replacement is the function built by ``make_load_weights`` for the
    given ``target_config``; it is assigned on the class itself, so it
    applies to every instance of ``Eagle3LlamaForCausalLM``.

    Args:
        target_config: Configuration object forwarded unchanged to
            ``make_load_weights``.
    """
    patched_loader = make_load_weights(target_config)
    Eagle3LlamaForCausalLM.load_weights = patched_loader
|
||||
|
||||
|
||||
def make_load_weights(target_config):
    """Build a ``load_weights`` replacement that rotates the draft model's fc weight.

    The rotation file location is read from
    ``target_config.quant_config.quant_description['optional']['quarot']
    ['rotation_map']['global_rotation']`` and resolved relative to
    ``target_config.model_config.model``. The ``global_rotation`` tensor in
    that safetensors file is expanded into a block-diagonal matrix made of
    three copies of itself. If the entry is missing, or the file cannot be
    loaded, an error is logged and the returned loader applies no rotation.

    Args:
        target_config: Object exposing ``quant_config`` and
            ``model_config.model``.

    Returns:
        A function ``load_weights(self, weights)`` meant to be installed as a
        method on the draft model class.
    """
    logger = logging.getLogger(__name__)
    quant_cfg = target_config.quant_config
    rotation_matrix3 = None

    # NOTE(review): this is joined as a filesystem path below; presumably a
    # local model directory is expected -- confirm behaviour for hub IDs.
    model_path = target_config.model_config.model
    try:
        rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
    except (AttributeError, KeyError, TypeError) as e:
        # AttributeError: the quant config has no `quant_description`.
        # TypeError: an intermediate entry is not a mapping (e.g. None).
        # Both mean "not a QuaRot checkpoint" just like a missing key, and
        # must not abort drafter loading.
        logger.error(
            "Invalid quant_config: missing key "
            "quant_description['optional']['quarot']['rotation_map']['global_rotation']. "
            "If you don't use quarot model, please ignore it. "
            f"Error: {e}"
        )
    else:
        rotation_path = Path(model_path) / rotation_rel_path
        try:
            safetensor_data = load_file(rotation_path)
            Q = safetensor_data["global_rotation"]
            # Three diagonal copies of Q; presumably one per auxiliary hidden
            # state concatenated into the fc input -- TODO confirm.
            rotation_matrix3 = torch.block_diag(Q, Q, Q)
        except Exception as e:
            # Best effort: a missing or unreadable rotation file only
            # disables the rotation, it does not stop model loading.
            logger.error(
                f"Failed to load rotation weight from '{rotation_path}'. "
                "If you don't use quarot model, please ignore it. "
                f"Error: {e}"
            )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """Rename, optionally rotate, and load the draft model weights.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs from the checkpoint.
        """
        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            # Entries whose name contains "t2d" are dropped entirely.
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "fc." in name and rotation_matrix3 is not None:
                # Align both dtype and device with the incoming weight so the
                # matmul does not fail on a device mismatch.
                rotation = rotation_matrix3.to(
                    device=loaded_weight.device,
                    dtype=loaded_weight.dtype,
                )
                loaded_weight = loaded_weight @ rotation
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # Parameters the checkpoint did not provide (or that the model does
        # not use) are passed to the loader as substrings to skip.
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not includes_embed_tokens:
            skip_substrs.append("embed_tokens")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights
|
||||
@@ -106,6 +106,7 @@ from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||
from vllm_ascend.eplb.utils import model_register
|
||||
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||
from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
|
||||
from vllm_ascend.sample.sampler import AscendSampler
|
||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
@@ -2422,6 +2423,8 @@ class NPUModelRunner(GPUModelRunner):
|
||||
model_register(self.model)
|
||||
if self.drafter:
|
||||
logger.info("Loading drafter model...")
|
||||
if self.vllm_config.quant_config is not None:
|
||||
patch_load_weights(self.vllm_config)
|
||||
with get_tp_context(self.drafter):
|
||||
self.drafter.load_model(self.model)
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
|
||||
Reference in New Issue
Block a user