Add DeepseekV3 and Llama4 support

This commit is contained in:
Chranos
2026-02-11 14:30:01 +08:00
parent 8ac7afcbd3
commit 96ed925486
2 changed files with 31 additions and 5 deletions

View File

@@ -446,6 +446,12 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
# Llama4ForConditionalGeneration uses top-level Llama4Config
# which has text_config sub-config. Extract it for text model.
text_config = getattr(config, "text_config", None)
if text_config is not None:
vllm_config.model_config.hf_config = text_config
config = text_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
@@ -553,8 +559,19 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
if getattr(self.config, "tie_word_embeddings", False)
else None),
)
weights = [
self.permute_qk_weight_for_rotary(name, loaded_weight)
for name, loaded_weight in weights
]
loader.load_weights(weights)
def _process_weights(weights):
    """Lazily rename and filter checkpoint entries before loading.

    For each ``(name, tensor)`` pair in ``weights``:

    * a leading ``language_model.`` is removed from the name (the prefix
      used by Llama4ForConditionalGeneration checkpoints);
    * entries whose name starts with ``multi_modal_projector.``,
      ``vision_encoder.`` or ``vision_model.`` are dropped;
    * everything else is passed through
      ``self.permute_qk_weight_for_rotary`` and yielded.
    """
    text_prefix = "language_model."
    skipped_prefixes = (
        "multi_modal_projector.",
        "vision_encoder.",
        "vision_model.",
    )
    for raw_name, tensor in weights:
        if raw_name.startswith(text_prefix):
            # Keep only the part after the wrapper prefix.
            raw_name = raw_name[len(text_prefix):]
        elif raw_name.startswith(skipped_prefixes):
            # Projector / vision weights are not loaded by this path.
            continue
        final_name, final_tensor = self.permute_qk_weight_for_rotary(
            raw_name, tensor)
        yield final_name, final_tensor
loader.load_weights(_process_weights(weights))

View File

@@ -389,6 +389,15 @@ def vllm__llama4__Llama4ForCausalLM__load_weights(
if "rotary_emb.inv_freq" in name:
continue
# Strip language_model. prefix for Llama4ForConditionalGeneration
if name.startswith("language_model."):
name = name[len("language_model."):]
# Skip vision and multimodal-projector weights (not used by the text model)
elif (name.startswith("multi_modal_projector.")
or name.startswith("vision_encoder.")
or name.startswith("vision_model.")):
continue
# Permute Q/K weights for rotary embedding
name, loaded_weight = self.permute_qk_weight_for_rotary(
name, loaded_weight)