From 6a7115fa0df6afab7bb96402167eb959dc295134 Mon Sep 17 00:00:00 2001 From: drslark <96540755+drslark@users.noreply.github.com> Date: Mon, 9 Mar 2026 10:43:06 +0800 Subject: [PATCH] [main][feature] Support quarot for eagle3 without embedding (#7038) ### What this PR does / why we need it? If an `eagle3` model without embed_tokens works with a `quarot` target model, the acceptance rate will drop. We solve it in this PR. The related vLLM PR is https://github.com/vllm-project/vllm/pull/36225. - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d Signed-off-by: drslark --- vllm_ascend/patch/__init__.py | 4 +- vllm_ascend/patch/worker/__init__.py | 2 +- .../patch/worker/patch_draft_quarot.py | 143 ++++++++++++++++++ .../patch/worker/patch_qwen3_quarot.py | 79 ---------- vllm_ascend/worker/model_runner_v1.py | 2 +- 5 files changed, 148 insertions(+), 82 deletions(-) create mode 100644 vllm_ascend/patch/worker/patch_draft_quarot.py delete mode 100644 vllm_ascend/patch/worker/patch_qwen3_quarot.py diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index e811c128..7fae512b 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -319,7 +319,7 @@ # https://github.com/vllm-project/vllm/pull/34336 # Future Plan: # Remove this patch when vLLM merges the PR. -# ** 16. File: worker/patch_qwen3_quarot.py** +# ** 16. File: worker/patch_draft_quarot.py** # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights` # Why: @@ -328,5 +328,7 @@ # How: # Dynamically replace the `load_weights` function at runtime, # and fix `target_config` into the new implementation with a closure. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/36225 # Future Plan: # Remove this patch when vLLM merges the PR.
diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py
index 60320de0..f5a50aaf 100644
--- a/vllm_ascend/patch/worker/__init__.py
+++ b/vllm_ascend/patch/worker/__init__.py
@@ -35,4 +35,4 @@ import vllm_ascend.patch.worker.patch_huanyuan_vl  # noqa
 import vllm_ascend.patch.worker.patch_routed_experts_capturer  # noqa
 import vllm_ascend.patch.worker.patch_npugraph_ex_triton  # noqa
 import vllm_ascend.patch.worker.patch_kimi_k25  # noqa
-import vllm_ascend.patch.worker.patch_qwen3_quarot  # noqa
+import vllm_ascend.patch.worker.patch_draft_quarot  # noqa
diff --git a/vllm_ascend/patch/worker/patch_draft_quarot.py b/vllm_ascend/patch/worker/patch_draft_quarot.py
new file mode 100644
index 00000000..40965893
--- /dev/null
+++ b/vllm_ascend/patch/worker/patch_draft_quarot.py
@@ -0,0 +1,165 @@
+import logging
+import os
+from collections.abc import Iterable
+from pathlib import Path
+
+import torch
+from safetensors import safe_open
+from safetensors.torch import load_file
+from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.model_executor.models.utils import (
+    AutoWeightsLoader,
+    process_eagle_weight,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def get_embedding_tensor(directory_path):
+    """
+    Finds the token embedding tensor among the `.safetensors` files of a directory.
+
+    Only the returned tensor is materialized, the other tensors of the files are not loaded.
+    A key ending with 'embed_tokens.weight' is preferred, otherwise the first key that
+    contains 'embed' is used. Returns None if `directory_path` is not a directory or if
+    no key matches.
+    """
+    if not os.path.isdir(directory_path):
+        return None
+
+    # (file_path, key) of the first loosely matched tensor, only used as a fallback
+    fallback = None
+    # sort the file names, since the order of `os.listdir` is arbitrary
+    for filename in sorted(os.listdir(directory_path)):
+        if not filename.endswith(".safetensors"):
+            continue
+        file_path = os.path.join(directory_path, filename)
+
+        with safe_open(file_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                lowered_key = key.lower()
+                if lowered_key.endswith("embed_tokens.weight"):
+                    return f.get_tensor(key)
+                if fallback is None and "embed" in lowered_key:
+                    fallback = (file_path, key)
+
+    if fallback is None:
+        return None
+
+    file_path, key = fallback
+    with safe_open(file_path, framework="pt", device="cpu") as f:
+        return f.get_tensor(key)
+
+
+def get_rotation_path(target_vllm_config):
+    """
+    Gets the path of the rotation matrix, returns None if the target model is not a quarot model.
+    """
+    target_model_path = target_vllm_config.model_config.model
+    try:
+        quant_description = target_vllm_config.quant_config.quant_description
+        rotation_relative_path = quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
+    except (AttributeError, KeyError, TypeError):
+        # AttributeError: `quant_config` is None or has no `quant_description`.
+        # KeyError / TypeError: the description has no quarot rotation entry.
+        return None
+
+    return Path(target_model_path) / rotation_relative_path
+
+
+def get_rotataion_matrix(rotation_path):
+    """
+    Loads the global rotation matrix from `rotation_path`.
+
+    If loading fails, the failure is logged and the original exception is re-raised.
+    """
+    try:
+        return load_file(rotation_path)["global_rotation"]
+    except Exception:
+        logger.error(
+            "Failed to load rotation weight from '%s'. "
+            "If you want to use quarot model with eagle3, take a check.",
+            rotation_path,
+        )
+        raise
+
+
+def compute_rotataion_matrix3(Q):
+    """
+    Anti-rotate matrix for 3 layers of hidden_states.
+    """
+    return torch.block_diag(Q, Q, Q)
+
+
+def patch_load_weights(target_vllm_config):
+    """
+    Replaces `Eagle3LlamaForCausalLM.load_weights` if the target model is a quarot model.
+    """
+    rotation_path = get_rotation_path(target_vllm_config)
+
+    # if rotation path is not found, then quarot is not in use.
+    if rotation_path is None:
+        return
+
+    target_model_path = Path(target_vllm_config.model_config.model)
+    Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_model_path, rotation_path)
+
+
+def make_load_weights(target_model_path, rotation_path):
+    """
+    Builds a `load_weights` for `Eagle3LlamaForCausalLM` with both paths fixed by a closure.
+    """
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        Q = get_rotataion_matrix(rotation_path)
+        Q3 = compute_rotataion_matrix3(Q)
+        if isinstance(self.config.dtype, str):
+            embed_dtype = getattr(torch, self.config.dtype)
+        else:
+            embed_dtype = self.config.dtype
+
+        model_weights = {}
+        includes_draft_id_mapping = False
+        includes_embed_tokens = False
+        for name, loaded_weight in weights:
+            if "t2d" in name:
+                continue
+            if "d2t" in name:
+                name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
+            elif "lm_head" not in name:
+                name = "model." + name
+            if "fc." in name:
+                # anti-rotate fc
+                loaded_weight = loaded_weight @ Q3.to(loaded_weight.dtype)
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
+            model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
+
+        # process embedding if drafter does not have embedding
+        if not includes_embed_tokens:
+            name = "model.embed_tokens.weight"
+            embedding = get_embedding_tensor(target_model_path)
+            if embedding is None:
+                raise ValueError(
+                    f"Failed to find the embedding weight in '{target_model_path}'. "
+                    "The drafter has no embed_tokens, so the embedding of the target model is required."
+                )
+            model_weights[name] = embedding.to(embed_dtype) @ Q.T.to(embed_dtype)
+            process_eagle_weight(self, name)
+
+        # embed_tokens is always in `model_weights` here, so it is never skipped
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not self.model.use_aux_hidden_state:
+            skip_substrs.append("fc.")
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=skip_substrs,
+        )
+        loader.load_weights(model_weights.items())
+
+    return load_weights
diff --git a/vllm_ascend/patch/worker/patch_qwen3_quarot.py b/vllm_ascend/patch/worker/patch_qwen3_quarot.py
deleted file mode 100644
index 3780c664..00000000
--- a/vllm_ascend/patch/worker/patch_qwen3_quarot.py
+++ /dev/null @@ -1,79 +0,0 @@ -import logging -from collections.abc import Iterable -from pathlib import Path - -import torch -from safetensors.torch import load_file -from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.model_executor.models.utils import ( - AutoWeightsLoader, - process_eagle_weight, -) - - -def patch_load_weights(target_config): - Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_config) - - -def make_load_weights(target_config): - logger = logging.getLogger(__name__) - quant_cfg = target_config.quant_config - rotation_matrix3 = None - - model_path = target_config.model_config.model - try: - rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"] - except KeyError as e: - logger.error( - "Invalid quant_config: missing key " - "quant_description['optional']['quarot']['rotation_map']['global_rotation']. " - "If you don't use quarot model, please ignore it. " - f"Error: {e}" - ) - else: - rotation_path = Path(model_path) / rotation_rel_path - try: - safetensor_data = load_file(rotation_path) - Q = safetensor_data["global_rotation"] - rotation_matrix3 = torch.block_diag(Q, Q, Q) - except Exception as e: - logger.error( - f"Failed to load rotation weight from '{rotation_path}'. " - "If you don't use quarot model, please ignore it. " - f"Error: {e}" - ) - - def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): - model_weights = {} - includes_draft_id_mapping = False - includes_embed_tokens = False - for name, loaded_weight in weights: - if "t2d" in name: - continue - if "d2t" in name: - name = name.replace("d2t", "draft_id_to_target_id") - includes_draft_id_mapping = True - elif "lm_head" not in name: - name = "model." + name - if "fc." 
in name and rotation_matrix3 is not None: - loaded_weight = loaded_weight @ rotation_matrix3.to(loaded_weight.dtype) - if "embed_tokens" in name: - includes_embed_tokens = True - model_weights[name] = loaded_weight - process_eagle_weight(self, name) - - skip_substrs = [] - if not includes_draft_id_mapping: - skip_substrs.append("draft_id_to_target_id") - if not includes_embed_tokens: - skip_substrs.append("embed_tokens") - if not self.model.use_aux_hidden_state: - skip_substrs.append("fc.") - loader = AutoWeightsLoader( - self, - skip_prefixes=None, - skip_substrs=skip_substrs, - ) - loader.load_weights(model_weights.items()) - - return load_weights diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 64f4c22d..21e4b44e 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -103,8 +103,8 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess from vllm_ascend.eplb.eplb_updator import EplbUpdator from vllm_ascend.eplb.utils import model_register from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin +from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort -from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights from vllm_ascend.sample.sampler import AscendSampler from vllm_ascend.spec_decode import get_spec_decode_method from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer