[main][feature] Support quarot for eagle3 without embedding (#7038)
### What this PR does / why we need it?
If an `eagle3` draft model without `embed_tokens` is used with a `quarot` target
model, the acceptance rate drops.
This PR fixes that.
The related vLLM PR is https://github.com/vllm-project/vllm/pull/36225.
- vLLM main:
4034c3d32e
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -319,7 +319,7 @@
|
|||||||
# https://github.com/vllm-project/vllm/pull/34336
|
# https://github.com/vllm-project/vllm/pull/34336
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when vLLM merges the PR.
|
# Remove this patch when vLLM merges the PR.
|
||||||
# ** 16. File: worker/patch_qwen3_quarot.py**
|
# ** 16. File: worker/patch_draft_quarot.py**
|
||||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
# 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights`
|
# 1. `vllm.model_executor.models.llama_eagle3.Eagle3LlamaForCausalLM.load_weights`
|
||||||
# Why:
|
# Why:
|
||||||
@@ -328,5 +328,7 @@
|
|||||||
# How:
|
# How:
|
||||||
# Dynamically replace the `load_weights` function at runtime,
|
# Dynamically replace the `load_weights` function at runtime,
|
||||||
# and fix `target_config` into the new implementation with a closure.
|
# and fix `target_config` into the new implementation with a closure.
|
||||||
|
# Related PR (if no, explain why):
|
||||||
|
# https://github.com/vllm-project/vllm/pull/36225
|
||||||
# Future Plan:
|
# Future Plan:
|
||||||
# Remove this patch when vLLM merges the PR.
|
# Remove this patch when vLLM merges the PR.
|
||||||
|
|||||||
@@ -35,4 +35,4 @@ import vllm_ascend.patch.worker.patch_huanyuan_vl # noqa
|
|||||||
import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa
|
import vllm_ascend.patch.worker.patch_routed_experts_capturer # noqa
|
||||||
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
|
import vllm_ascend.patch.worker.patch_npugraph_ex_triton # noqa
|
||||||
import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
|
import vllm_ascend.patch.worker.patch_kimi_k25 # noqa
|
||||||
import vllm_ascend.patch.worker.patch_qwen3_quarot # noqa
|
import vllm_ascend.patch.worker.patch_draft_quarot # noqa
|
||||||
|
|||||||
143
vllm_ascend/patch/worker/patch_draft_quarot.py
Normal file
143
vllm_ascend/patch/worker/patch_draft_quarot.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from safetensors.torch import load_file
|
||||||
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
|
from vllm.model_executor.models.utils import (
|
||||||
|
AutoWeightsLoader,
|
||||||
|
process_eagle_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_tensor(directory_path):
    """
    Return the first tensor whose key contains 'embed' from the safetensors shards of a directory.

    Shards (``*.safetensors``) are visited in sorted filename order, so the result does not
    depend on the order in which the filesystem lists them. Each shard is opened lazily and
    only the matching tensor is materialized, instead of loading every tensor of every shard.

    Args:
        directory_path: Directory expected to contain ``*.safetensors`` files.

    Returns:
        The first matching tensor, or None when the path is not a directory, the directory
        holds no ``.safetensors`` file, or no key contains 'embed' (case-insensitive).
    """
    if not os.path.isdir(directory_path):
        return None

    shard_names = sorted(name for name in os.listdir(directory_path) if name.endswith(".safetensors"))
    if not shard_names:
        return None

    # Local import: only this lazy-reading path needs `safe_open`.
    from safetensors import safe_open

    for shard_name in shard_names:
        shard_path = os.path.join(directory_path, shard_name)
        with safe_open(shard_path, framework="pt") as shard:
            for key in shard.keys():
                if "embed" in key.lower():
                    # Return immediately once found; only this tensor is read from disk.
                    return shard.get_tensor(key)

    return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotation_path(target_vllm_config):
    """
    Gets the path of the rotation matrix, returns None if the target model is not a quarot model.

    Args:
        target_vllm_config: Config object of the target model. Reads
            ``model_config.model`` (model directory) and
            ``quant_config.quant_description`` (nested mapping).

    Returns:
        ``Path(model) / <global_rotation entry>`` when
        ``quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]`` exists,
        otherwise None.
    """
    target_model_path = target_vllm_config.model_config.model
    try:
        quant_description = target_vllm_config.quant_config.quant_description
        rotation_relative_path = quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
    except (AttributeError, KeyError, TypeError):
        # AttributeError: quant_config is None or has no quant_description (target not quantized).
        # KeyError / TypeError: the description has no quarot entry, or a link in the chain
        # is not a mapping. All of these mean "quarot is not in use".
        return None

    return Path(target_model_path) / rotation_relative_path
|
||||||
|
|
||||||
|
|
||||||
|
def get_rotataion_matrix(rotation_path):
    """
    Load the ``global_rotation`` tensor from the safetensors file at ``rotation_path``.

    Args:
        rotation_path: Path of the safetensors file that stores the rotation matrix.

    Returns:
        The tensor stored under the ``"global_rotation"`` key.

    Raises:
        Exception: Any error raised while reading the file or looking up the key is logged
            and then re-raised unchanged.
    """
    try:
        safetensor_data = load_file(rotation_path)
        return safetensor_data["global_rotation"]
    except Exception:
        logger.error(
            f"Failed to load rotation weight from '{rotation_path}'. "
            "If you want to use quarot model with eagle3, take a check."
        )
        # Bare `raise` re-raises the active exception with its original traceback.
        raise
|
||||||
|
|
||||||
|
|
||||||
|
def compute_rotataion_matrix3(Q, num_blocks=3):
    """
    Build a block-diagonal matrix that repeats ``Q`` along the diagonal.

    With the default ``num_blocks=3`` this is the anti-rotate matrix for 3 layers of
    hidden_states, i.e. ``torch.block_diag(Q, Q, Q)``.

    Args:
        Q: Rotation matrix to repeat.
        num_blocks: Number of copies of ``Q`` placed on the diagonal (default 3).

    Returns:
        The block-diagonal tensor containing ``num_blocks`` copies of ``Q``.

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """
    if num_blocks < 1:
        raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")
    return torch.block_diag(*([Q] * num_blocks))
|
||||||
|
|
||||||
|
|
||||||
|
def patch_load_weights(target_vllm_config):
    """
    Replace ``Eagle3LlamaForCausalLM.load_weights`` when the target model uses quarot.

    The replacement is built by ``make_load_weights`` from the target model directory and
    the rotation-matrix path. When ``get_rotation_path`` returns None (quarot is not in
    use), the class is left untouched.
    """
    model_dir = Path(target_vllm_config.model_config.model)
    rotation_file = get_rotation_path(target_vllm_config)

    # A rotation path is only present for quarot targets; otherwise there is nothing to patch.
    if rotation_file is not None:
        Eagle3LlamaForCausalLM.load_weights = make_load_weights(model_dir, rotation_file)
|
||||||
|
|
||||||
|
|
||||||
|
def make_load_weights(target_model_path, rotation_path):
    """
    Build a replacement ``load_weights`` method bound to a quarot target model.

    Args:
        target_model_path: Directory of the target model; searched for an embedding tensor
            when the draft checkpoint provides no ``embed_tokens`` weight.
        rotation_path: Safetensors file holding the ``global_rotation`` matrix.

    Returns:
        A function with the ``load_weights(self, weights)`` signature that captures both
        paths in its closure.
    """

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        """
        Load draft weights, anti-rotating ``fc.`` weights and filling in a missing embedding.

        Raises:
            ValueError: If the draft checkpoint has no ``embed_tokens`` weight and no
                embedding tensor can be found under the target model directory.
        """
        Q = get_rotataion_matrix(rotation_path)
        Q3 = compute_rotataion_matrix3(Q)

        # config.dtype may be a string such as "bfloat16" or already a torch.dtype.
        if isinstance(self.config.dtype, str):
            embed_dtype = getattr(torch, self.config.dtype)
        else:
            embed_dtype = self.config.dtype

        model_weights = {}
        includes_draft_id_mapping = False
        includes_embed_tokens = False
        for name, loaded_weight in weights:
            if "t2d" in name:
                continue
            if "d2t" in name:
                name = name.replace("d2t", "draft_id_to_target_id")
                includes_draft_id_mapping = True
            elif "lm_head" not in name:
                name = "model." + name
            if "fc." in name:
                # anti-rotate fc
                dtype = loaded_weight.dtype
                loaded_weight = loaded_weight @ Q3.to(dtype)
            if "embed_tokens" in name:
                includes_embed_tokens = True
            model_weights[name] = loaded_weight
            process_eagle_weight(self, name)

        # process embedding if drafter does not have embedding
        if not includes_embed_tokens:
            name = "model.embed_tokens.weight"
            target_embedding = get_embedding_tensor(target_model_path)
            if target_embedding is None:
                # Fail with an actionable message instead of an AttributeError on None.
                raise ValueError(
                    f"No embedding tensor found under '{target_model_path}'; cannot build "
                    "'model.embed_tokens.weight' for a draft model without embed_tokens."
                )
            model_weights[name] = target_embedding.to(embed_dtype) @ Q.T.to(embed_dtype)
            process_eagle_weight(self, name)

        # embed_tokens is always present at this point (loaded from the draft checkpoint or
        # derived from the target above), so it is never added to the skip list.
        skip_substrs = []
        if not includes_draft_id_mapping:
            skip_substrs.append("draft_id_to_target_id")
        if not self.model.use_aux_hidden_state:
            skip_substrs.append("fc.")
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=None,
            skip_substrs=skip_substrs,
        )
        loader.load_weights(model_weights.items())

    return load_weights
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
import logging
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from safetensors.torch import load_file
|
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
|
||||||
from vllm.model_executor.models.utils import (
|
|
||||||
AutoWeightsLoader,
|
|
||||||
process_eagle_weight,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def patch_load_weights(target_config):
|
|
||||||
Eagle3LlamaForCausalLM.load_weights = make_load_weights(target_config)
|
|
||||||
|
|
||||||
|
|
||||||
def make_load_weights(target_config):
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
quant_cfg = target_config.quant_config
|
|
||||||
rotation_matrix3 = None
|
|
||||||
|
|
||||||
model_path = target_config.model_config.model
|
|
||||||
try:
|
|
||||||
rotation_rel_path = quant_cfg.quant_description["optional"]["quarot"]["rotation_map"]["global_rotation"]
|
|
||||||
except KeyError as e:
|
|
||||||
logger.error(
|
|
||||||
"Invalid quant_config: missing key "
|
|
||||||
"quant_description['optional']['quarot']['rotation_map']['global_rotation']. "
|
|
||||||
"If you don't use quarot model, please ignore it. "
|
|
||||||
f"Error: {e}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
rotation_path = Path(model_path) / rotation_rel_path
|
|
||||||
try:
|
|
||||||
safetensor_data = load_file(rotation_path)
|
|
||||||
Q = safetensor_data["global_rotation"]
|
|
||||||
rotation_matrix3 = torch.block_diag(Q, Q, Q)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Failed to load rotation weight from '{rotation_path}'. "
|
|
||||||
"If you don't use quarot model, please ignore it. "
|
|
||||||
f"Error: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
|
||||||
model_weights = {}
|
|
||||||
includes_draft_id_mapping = False
|
|
||||||
includes_embed_tokens = False
|
|
||||||
for name, loaded_weight in weights:
|
|
||||||
if "t2d" in name:
|
|
||||||
continue
|
|
||||||
if "d2t" in name:
|
|
||||||
name = name.replace("d2t", "draft_id_to_target_id")
|
|
||||||
includes_draft_id_mapping = True
|
|
||||||
elif "lm_head" not in name:
|
|
||||||
name = "model." + name
|
|
||||||
if "fc." in name and rotation_matrix3 is not None:
|
|
||||||
loaded_weight = loaded_weight @ rotation_matrix3.to(loaded_weight.dtype)
|
|
||||||
if "embed_tokens" in name:
|
|
||||||
includes_embed_tokens = True
|
|
||||||
model_weights[name] = loaded_weight
|
|
||||||
process_eagle_weight(self, name)
|
|
||||||
|
|
||||||
skip_substrs = []
|
|
||||||
if not includes_draft_id_mapping:
|
|
||||||
skip_substrs.append("draft_id_to_target_id")
|
|
||||||
if not includes_embed_tokens:
|
|
||||||
skip_substrs.append("embed_tokens")
|
|
||||||
if not self.model.use_aux_hidden_state:
|
|
||||||
skip_substrs.append("fc.")
|
|
||||||
loader = AutoWeightsLoader(
|
|
||||||
self,
|
|
||||||
skip_prefixes=None,
|
|
||||||
skip_substrs=skip_substrs,
|
|
||||||
)
|
|
||||||
loader.load_weights(model_weights.items())
|
|
||||||
|
|
||||||
return load_weights
|
|
||||||
@@ -103,8 +103,8 @@ from vllm_ascend.eplb.core.eplb_worker import EplbProcess
|
|||||||
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
from vllm_ascend.eplb.eplb_updator import EplbUpdator
|
||||||
from vllm_ascend.eplb.utils import model_register
|
from vllm_ascend.eplb.utils import model_register
|
||||||
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
from vllm_ascend.ops.rotary_embedding import set_cos_and_sin, update_cos_sin
|
||||||
|
from vllm_ascend.patch.worker.patch_draft_quarot import patch_load_weights
|
||||||
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
from vllm_ascend.patch.worker.patch_module import patch_torch_npu_argsort
|
||||||
from vllm_ascend.patch.worker.patch_qwen3_quarot import patch_load_weights
|
|
||||||
from vllm_ascend.sample.sampler import AscendSampler
|
from vllm_ascend.sample.sampler import AscendSampler
|
||||||
from vllm_ascend.spec_decode import get_spec_decode_method
|
from vllm_ascend.spec_decode import get_spec_decode_method
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import AscendEagleProposer
|
||||||
|
|||||||
Reference in New Issue
Block a user