forked from EngineX-Cambricon/enginex-mlu370-vllm
Add DeepSeek-V3 and Llama 4 model support
This commit is contained in:
@@ -446,6 +446,12 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
# Llama4ForConditionalGeneration uses top-level Llama4Config
|
||||
# which has text_config sub-config. Extract it for text model.
|
||||
text_config = getattr(config, "text_config", None)
|
||||
if text_config is not None:
|
||||
vllm_config.model_config.hf_config = text_config
|
||||
config = text_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
@@ -553,8 +559,19 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
|
||||
if getattr(self.config, "tie_word_embeddings", False)
|
||||
else None),
|
||||
)
|
||||
weights = [
|
||||
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
||||
for name, loaded_weight in weights
|
||||
]
|
||||
loader.load_weights(weights)
|
||||
|
||||
def _process_weights(weights):
    """Adapt raw checkpoint weights for loading into the text-only model.

    Yields ``(name, weight)`` pairs suitable for the text model's loader:

    * Entries prefixed with ``language_model.`` (as emitted by
      ``Llama4ForConditionalGeneration`` checkpoints) have that prefix
      stripped so the names match this text model's parameters.
    * Vision-side entries (``multi_modal_projector.``, ``vision_encoder.``,
      ``vision_model.``) are dropped entirely — the text model has no
      matching parameters for them.
    * Every surviving Q/K projection weight is run through
      ``self.permute_qk_weight_for_rotary`` (``self`` is captured from the
      enclosing method scope) to match this model's rotary-embedding layout.
    """
    text_prefix = "language_model."
    # Any name starting with one of these belongs to the vision tower.
    vision_prefixes = ("multi_modal_projector.",
                       "vision_encoder.",
                       "vision_model.")
    for name, loaded_weight in weights:
        if name.startswith(text_prefix):
            # Strip language_model. prefix for Llama4ForConditionalGeneration
            name = name[len(text_prefix):]
        elif name.startswith(vision_prefixes):
            # Skip vision encoder weights
            continue
        name, loaded_weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)
        yield name, loaded_weight
|
||||
|
||||
loader.load_weights(_process_weights(weights))
|
||||
|
||||
@@ -389,6 +389,15 @@ def vllm__llama4__Llama4ForCausalLM__load_weights(
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
# Strip language_model. prefix for Llama4ForConditionalGeneration
|
||||
if name.startswith("language_model."):
|
||||
name = name[len("language_model."):]
|
||||
# Skip vision encoder weights
|
||||
elif (name.startswith("multi_modal_projector.")
|
||||
or name.startswith("vision_encoder.")
|
||||
or name.startswith("vision_model.")):
|
||||
continue
|
||||
|
||||
# Permute Q/K weights for rotary embedding
|
||||
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
||||
name, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user