add deepseekv3 and llama4
This commit is contained in:
@@ -446,6 +446,12 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
|
|||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
config = vllm_config.model_config.hf_config
|
config = vllm_config.model_config.hf_config
|
||||||
|
# Llama4ForConditionalGeneration uses top-level Llama4Config
|
||||||
|
# which has text_config sub-config. Extract it for text model.
|
||||||
|
text_config = getattr(config, "text_config", None)
|
||||||
|
if text_config is not None:
|
||||||
|
vllm_config.model_config.hf_config = text_config
|
||||||
|
config = text_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
lora_config = vllm_config.lora_config
|
lora_config = vllm_config.lora_config
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -553,8 +559,19 @@ class Llama4ForCausalLM(nn.Module, SupportsPP):
|
|||||||
if getattr(self.config, "tie_word_embeddings", False)
|
if getattr(self.config, "tie_word_embeddings", False)
|
||||||
else None),
|
else None),
|
||||||
)
|
)
|
||||||
weights = [
|
|
||||||
self.permute_qk_weight_for_rotary(name, loaded_weight)
|
def _process_weights(weights):
|
||||||
for name, loaded_weight in weights
|
for name, loaded_weight in weights:
|
||||||
]
|
# Strip language_model. prefix for Llama4ForConditionalGeneration
|
||||||
loader.load_weights(weights)
|
if name.startswith("language_model."):
|
||||||
|
name = name[len("language_model."):]
|
||||||
|
# Skip vision encoder weights
|
||||||
|
elif name.startswith("multi_modal_projector.") or \
|
||||||
|
name.startswith("vision_encoder.") or \
|
||||||
|
name.startswith("vision_model."):
|
||||||
|
continue
|
||||||
|
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
||||||
|
name, loaded_weight)
|
||||||
|
yield name, loaded_weight
|
||||||
|
|
||||||
|
loader.load_weights(_process_weights(weights))
|
||||||
|
|||||||
@@ -389,6 +389,15 @@ def vllm__llama4__Llama4ForCausalLM__load_weights(
|
|||||||
if "rotary_emb.inv_freq" in name:
|
if "rotary_emb.inv_freq" in name:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# Strip language_model. prefix for Llama4ForConditionalGeneration
|
||||||
|
if name.startswith("language_model."):
|
||||||
|
name = name[len("language_model."):]
|
||||||
|
# Skip vision encoder weights
|
||||||
|
elif (name.startswith("multi_modal_projector.")
|
||||||
|
or name.startswith("vision_encoder.")
|
||||||
|
or name.startswith("vision_model.")):
|
||||||
|
continue
|
||||||
|
|
||||||
# Permute Q/K weights for rotary embedding
|
# Permute Q/K weights for rotary embedding
|
||||||
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
name, loaded_weight = self.permute_qk_weight_for_rotary(
|
||||||
name, loaded_weight)
|
name, loaded_weight)
|
||||||
|
|||||||
Reference in New Issue
Block a user