forked from EngineX-Cambricon/enginex-mlu370-vllm
add deepseekv3 and llama4
@@ -582,15 +582,24 @@ def unified_flash_attention_v2(
     else:
         # unpaged (linear cache) path
         if use_mla:
-            # The MLA cache is 2D (total_slots, head_dim); reshape_paged_cache
-            # expects 4D, so write directly by indexing instead.
-            # key: (num_tokens, 1, head_dim) → squeeze → (num_tokens, head_dim)
-            # key_cache: (total_slots, head_dim)
-            key_cache[updated_slot_mapping.flatten()] = key.squeeze(1)
+            # MLA: mirror the paged path's handling.
+            # key_cache: (num_blocks, 1, block_size, 576)
+            value_to_cache = None
+            if attn_metadata.prefill_metadata:
+                # The MLA prefill cache was already written in forward_prefill; skip.
+                pass
+            else:
+                if kv_cache_dtype == 'int8':
+                    mlu_ops.quant_to_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        key_cache_scale, value_cache_scale,
+                        updated_slot_mapping.flatten())
+                else:
+                    mlu_ops.reshape_paged_cache(
+                        key, value_to_cache,
+                        key_cache, value_cache,
+                        updated_slot_mapping.flatten())
         else:
             # FIXME: After TMO-1496 is completed, remove this code.
             if key.stride() != value.stride():
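
The removed unpaged MLA branch wrote the latent KV with plain advanced indexing. A toy, standalone illustration of that slot-mapping write (shapes invented for the example; this is not the MLU kernel or the real cache layout):

    import torch

    total_slots, head_dim, num_tokens = 16, 8, 3
    key_cache = torch.zeros(total_slots, head_dim)   # 2D linear MLA cache
    key = torch.randn(num_tokens, 1, head_dim)       # (num_tokens, 1, head_dim)
    slot_mapping = torch.tensor([4, 9, 2])           # one cache slot per token

    # scatter each token's latent KV into its assigned slot
    key_cache[slot_mapping] = key.squeeze(1)
    assert torch.equal(key_cache[9], key[1, 0])
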
@@ -37,7 +37,7 @@ def vllm__config__CacheConfig___verify_cache_dtype(self) -> None:
 
 def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
     """Returns the number of KV heads per GPU."""
-    if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type == 'deepseek_v2':
+    if hasattr(self.hf_text_config,"model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
         # feature flag MLA
         return 1
     total_num_kv_heads = self.get_total_num_kv_heads()
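
For MLA models the compressed latent KV is shared by every attention head, so the per-GPU KV-head count is pinned to 1 regardless of tensor parallelism; other models fall through to the usual divide-by-TP rule. A minimal standalone sketch of that fallback convention (a hypothetical helper, not the hijacked method itself):

    def num_kv_heads_per_gpu(total_num_kv_heads: int, tp_size: int) -> int:
        # If the TP degree reaches the KV-head count, heads are replicated;
        # otherwise they are sharded evenly across ranks.
        return max(1, total_num_kv_heads // tp_size)
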
@@ -51,7 +51,7 @@ def vllm__config__ModelConfig__get_num_kv_heads(self, parallel_config: "Parallel
 def vllm__config__ModelConfig__get_head_size(self) -> int:
     # TODO remove hard code
     if hasattr(self.hf_text_config, "model_type"
-               ) and self.hf_text_config.model_type == 'deepseek_v2':
+               ) and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3'):
     '''
     =============================
     Modify by vllm_mlu
@@ -109,7 +109,7 @@ def vllm__config__LoRAConfig__verify_with_model_config(self, model_config: Model
 def vllm__config__ModelConfig__is_deepseek_v2(self) -> bool:
     result = hasattr(
         self.hf_text_config,
-        "model_type") and self.hf_text_config.model_type == 'deepseek_v2'
+        "model_type") and self.hf_text_config.model_type in ('deepseek_v2', 'deepseek_v3')
     return result
 
 MluHijackObject.apply_hijack(ModelConfig,
@@ -39,3 +39,9 @@ try:
 except ImportError as e:
     import logging
     logging.warning(f"Failed to import mllama hijack: {e}")
+
+try:
+    import vllm_mlu.model_executor.models.llama4
+except ImportError as e:
+    import logging
+    logging.warning(f"Failed to import llama4 hijack: {e}")

vllm-v0.6.2/vllm_mlu/vllm_mlu/model_executor/models/llama4.py (new file, 485 lines)
@@ -0,0 +1,485 @@
import torch
import re

from typing import List, Optional, Tuple, Union, Iterable
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm_mlu.model_executor.layers.feed_forward import FeedForward
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.model_executor.model_loader.weight_utils import (
    default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.llama4 import (
    Llama4Attention, Llama4DecoderLayer, Llama4ForCausalLM,
    Llama4Model, Llama4MoE)
from vllm_mlu.model_executor.layers.sparse_moe_mlp import SparseMoeMlp
from vllm.model_executor.models.utils import is_pp_missing_parameter
from vllm.sequence import IntermediateTensors

from vllm_mlu.model_executor.models.layer_utils import (
    decoder_layer_forward_base, decoder_model_forward_base_pp,
    is_per_tensor_smoothquant, is_per_token_smoothquant,
    quant_fusion_with_rmsnorm)

from vllm.logger import init_logger

logger = init_logger(__name__)

# ============================================================
# Llama4MoE MLU replacement: SparseMoeMlp + shared expert
# ============================================================

class Llama4MoEMlu(SparseMoeMlp):
    """MLU replacement for Llama4MoE using SparseMoeMlp + shared expert."""

    def __init__(self, config, quant_config=None, prefix=""):
        num_local_experts = getattr(config, "num_local_experts", 8)
        top_k = getattr(config, "num_experts_per_tok", 1)
        hidden_size = getattr(config, "hidden_size", 4096)
        intermediate_size = getattr(config, "intermediate_size", 8192)

        super().__init__(
            num_experts=num_local_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            has_bias=False,
            skip_bias_add=False,
            renormalize=False,
            hidden_act="silu",
            params_dtype=None,
            quant_config=quant_config,
            is_use_fused_moe=True,
        )

        # Llama4 uses sigmoid routing, not softmax;
        # topk_softmax is overridden below to use sigmoid.
        self._use_sigmoid_routing = True

        # Shared expert (independent from routed experts)
        self.shared_expert = FeedForward(
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            hidden_act="silu",
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            bias=False,
            quant_config=quant_config,
            reduce_results=False,
            prefix=f"{prefix}.shared_expert",
        )

    def topk_softmax(self, expert_logits):
        """Override: Llama4 uses sigmoid routing instead of softmax."""
        topk_values, topk_indices = torch.topk(
            expert_logits, self.top_k, dim=-1)
        topk_values = torch.sigmoid(topk_values.float())
        return topk_values, topk_indices

    def forward(self, hidden_states, residual=None):
        orig_shape = hidden_states.shape
        hidden_states = hidden_states.view(-1, self.hidden_size)

        # Shared expert output
        shared_out = self.shared_expert(hidden_states)

        # Router logits
        router_logits, _ = self.gate(hidden_states)

        # Routed experts
        routed_out = self.forward_experts(hidden_states, router_logits, None)

        # Combine routed and shared outputs, then all-reduce across TP ranks
        final_out = routed_out + shared_out
        if self.tp_size > 1:
            final_out = tensor_model_parallel_all_reduce(final_out)

        return final_out.view(orig_shape)
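
Note the renormalize=False passed to SparseMoeMlp above: Llama4 gates its routed experts with independent sigmoid weights rather than a softmax distribution, so the top-k weights deliberately do not sum to 1. A toy run of the overridden selection logic (plain torch, outside the class):

    import torch

    logits = torch.tensor([[2.0, -1.0, 0.5, 3.0]])    # 4 experts, 1 token
    topk_vals, topk_idx = torch.topk(logits, k=2, dim=-1)
    weights = torch.sigmoid(topk_vals.float())        # sigmoid, no renormalize
    # topk_idx -> [[3, 0]]; weights -> [[0.9526, 0.8808]] (do not sum to 1)
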
# ============================================================
# Llama4Attention hijack
# ============================================================

vllm__llama4__Llama4Attention__init__org = Llama4Attention.__init__


def vllm__llama4__Llama4Attention____init__(
    self,
    config,
    hidden_size: int,
    num_heads: int,
    num_kv_heads: int,
    max_position_embeddings: int = 8192,
    quant_config: Optional[QuantizationConfig] = None,
    bias: bool = False,
    cache_config: Optional[CacheConfig] = None,
    prefix: str = "",
) -> None:
    vllm__llama4__Llama4Attention__init__org(
        self, config, hidden_size, num_heads, num_kv_heads,
        max_position_embeddings, quant_config, bias, cache_config, prefix)
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: save rope_scaling for MLU RoPE dispatch
    '''
    self.rope_scaling = getattr(config, "rope_scaling", None)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

def vllm__llama4__Llama4Attention__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
    residual: Optional[torch.Tensor] = None,
    smooth_quant_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states, smooth_quant_scale)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: MLU RoPE: merge q/k, apply rotary, split back (lesson #3).
    For NoPE layers (self.rotary_emb is None), skip RoPE entirely.
    '''
    if self.rotary_emb is not None:
        if (self.rope_scaling is not None
                and self.rope_scaling.get("rope_type") == "longrope"):
            q, k = self.rotary_emb(positions, q, k)
        else:
            qk, _ = qkv.split(
                [self.q_size + self.kv_size, self.kv_size], dim=-1)
            self.rotary_emb(
                positions,
                qk.view(-1, self.num_heads + self.num_kv_heads,
                        self.head_dim))
            q, k = qk.split([self.q_size, self.kv_size], dim=-1)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    # QK norm (lesson #2: use contiguous + reshape)
    if self.qk_norm is not None:
        q = q.contiguous().reshape(-1, self.head_dim)
        q = (self.qk_norm(q.float())
             .contiguous().reshape(-1, self.q_size).to(q.dtype))
        k = k.contiguous().reshape(-1, self.head_dim)
        k = (self.qk_norm(k.float())
             .contiguous().reshape(-1, self.kv_size).to(k.dtype))

    # Temperature tuning for NoPE layers
    if self.attn_temperature_tuning and self.nope:
        attn_scale = self._get_attn_scale(positions)
        q = (q * attn_scale).to(q.dtype)

    attn_output = self.attn(q, k, v, kv_cache, attn_metadata)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: add residual in o_proj
    '''
    output, _ = self.o_proj(attn_output, residual)
    '''
    ==================
    End of MLU Hijack
    ==================
    '''
    return output
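
The merge-then-split RoPE trick works because q, k, and qk are all split() views of the same fused qkv buffer, so one in-place rotary call over the merged view updates q and k at once. A torch-only toy of the aliasing property, with an in-place negation standing in for the MLU rotary kernel (a hypothetical placeholder):

    import torch

    num_heads, num_kv_heads, head_dim = 4, 2, 8
    q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
    qkv = torch.randn(3, q_size + 2 * kv_size)

    # split() returns views: q, k, and qk all alias qkv's storage
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    qk, _ = qkv.split([q_size + kv_size, kv_size], dim=-1)

    q_before = q.clone()
    # stand-in for the in-place MLU rotary kernel (hypothetical placeholder)
    qk.view(-1, num_heads + num_kv_heads, head_dim).neg_()

    # q (and k) observe the update because they alias qkv's storage
    assert torch.equal(q, -q_before)
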
# ============================================================
# Llama4DecoderLayer hijack
# ============================================================

def vllm__llama4__Llama4DecoderLayer____init__(
    self,
    config,
    cache_config: Optional[CacheConfig] = None,
    quant_config: Optional[QuantizationConfig] = None,
    prefix: str = "",
) -> None:
    super(Llama4DecoderLayer, self).__init__()
    from vllm.model_executor.models.llama4 import (
        _extract_layer_index, Llama4Attention)

    self.layer_idx = _extract_layer_index(prefix)
    self.hidden_size = getattr(config, "hidden_size", 4096)
    max_position_embeddings = getattr(
        config, "max_position_embeddings", 8192)

    self.self_attn = Llama4Attention(
        config=config,
        hidden_size=self.hidden_size,
        num_heads=getattr(config, "num_attention_heads", 32),
        num_kv_heads=getattr(config, "num_key_value_heads",
                             getattr(config, "num_attention_heads", 32)),
        max_position_embeddings=max_position_embeddings,
        quant_config=quant_config,
        bias=False,
        cache_config=cache_config,
        prefix=f"{prefix}.self_attn",
    )

    interleave_moe_layer_step = getattr(
        config, "interleave_moe_layer_step", 0)
    is_moe_layer = (interleave_moe_layer_step > 0
                    and (self.layer_idx + 1)
                    % interleave_moe_layer_step == 0)

    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: Replace MoE with Llama4MoEMlu (SparseMoeMlp + shared expert),
            replace dense MLP with FeedForward.
    '''
    if is_moe_layer:
        self.feed_forward = Llama4MoEMlu(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
    else:
        intermediate_size_mlp = getattr(
            config, "intermediate_size_mlp",
            getattr(config, "intermediate_size", 8192))
        self.feed_forward = FeedForward(
            hidden_size=self.hidden_size,
            intermediate_size=intermediate_size_mlp,
            hidden_act="silu",
            up_proj_name="gate_up_proj",
            is_gated=True,
            down_proj_name="down_proj",
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    rms_norm_eps = getattr(config, "rms_norm_eps", 1e-5)
    self.input_layernorm = RMSNorm(self.hidden_size, eps=rms_norm_eps)
    self.post_attention_layernorm = RMSNorm(
        self.hidden_size, eps=rms_norm_eps)

    self.is_per_tensor_sq_perf_cases = is_per_tensor_smoothquant(
        quant_config)
    self.is_per_token_sq_perf_cases = is_per_token_smoothquant(
        quant_config)
    if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        self.self_attn.qkv_proj.quant_method.skip_quant_input = True
    self.quant_fusion_attn_layernorm = None
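
The is_moe_layer test makes every interleave_moe_layer_step-th layer (counting from 1) a MoE layer. A quick worked check of the arithmetic, for illustration:

    interleave_moe_layer_step = 2
    moe_layers = [idx for idx in range(8)
                  if (idx + 1) % interleave_moe_layer_step == 0]
    # moe_layers == [1, 3, 5, 7]; even-indexed layers keep the dense MLP
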
def vllm__llama4__Llama4DecoderLayer__forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    kv_cache: torch.Tensor,
    attn_metadata: AttentionMetadata,
) -> torch.Tensor:
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: use decoder_layer_forward_base with residual-in-matmul
            and optional quant fusion.
    '''
    attn_layernorm = self.input_layernorm
    if self.is_per_tensor_sq_perf_cases or self.is_per_token_sq_perf_cases:
        if self.quant_fusion_attn_layernorm is None:
            if self.is_per_token_sq_perf_cases:
                attn_quant_scale = self.self_attn.qkv_proj.smooth
            else:
                attn_quant_scale = self.self_attn.qkv_proj.scale_to_int
            self.quant_fusion_attn_layernorm = quant_fusion_with_rmsnorm(
                self.input_layernorm, attn_quant_scale,
                dynamic_quant=self.is_per_token_sq_perf_cases)
        attn_layernorm = self.quant_fusion_attn_layernorm

    return decoder_layer_forward_base(
        positions, hidden_states, kv_cache, attn_metadata,
        attn_layernorm,
        self.self_attn,
        self.post_attention_layernorm,
        self.feed_forward,
        input_norm_fuse_en=self.is_per_token_sq_perf_cases)

# ============================================================
# Llama4Model hijack
# ============================================================

def vllm__llama4__Llama4Model__forward(
    self,
    input_ids: Optional[torch.Tensor],
    positions: torch.Tensor,
    kv_caches: List[torch.Tensor],
    attn_metadata: AttentionMetadata,
    intermediate_tensors: Optional[IntermediateTensors],
    inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    return decoder_model_forward_base_pp(
        input_ids, positions, kv_caches, attn_metadata,
        intermediate_tensors,
        self.layers, self.start_layer, self.end_layer,
        self.get_input_embeddings,
        self.norm,
        inputs_embeds)

# ============================================================
# Llama4ForCausalLM load_weights hijack
# ============================================================

def vllm__llama4__Llama4ForCausalLM__load_weights(
    self,
    weights: Iterable[Tuple[str, torch.Tensor]],
):
    '''
    =============================
    Modify by vllm_mlu
    =============================
    @brief: pack params for SparseMoeMlp (MoE layers)
    '''
    for name, m in self.model.named_modules():
        if isinstance(m, SparseMoeMlp):
            m.pack_params()

    start_expert_id = 0
    '''
    ==================
    End of MLU Hijack
    ==================
    '''

    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        if "rotary_emb.inv_freq" in name:
            continue

        # Permute Q/K weights for rotary embedding
        name, loaded_weight = self.permute_qk_weight_for_rotary(
            name, loaded_weight)

        '''
        =============================
        Modify by vllm_mlu
        =============================
        @brief: remap expert_id for distributed inference
        '''
        if (start_expert_id > 0
                and "feed_forward.experts." in name):
            match = re.search(r'experts\.\d+', name)
            if match:
                expert_str = match.group(0)
                expert_id = int(expert_str.split(".")[1])
                named_expert_id = expert_id - start_expert_id
                name = name.replace(
                    f"experts.{expert_id}",
                    f"experts.{named_expert_id}")
        '''
        ==================
        End of MLU Hijack
        ==================
        '''

        for (param_name, weight_name, shard_id) in stacked_params_mapping:
            if weight_name not in name:
                continue
            name = name.replace(weight_name, param_name)
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            # Skip experts not assigned to this worker
            if ("feed_forward.experts." in name
                    and name not in params_dict):
                continue
            param = params_dict[name]
            weight_loader = param.weight_loader
            weight_loader(param, loaded_weight, shard_id)
            break
        else:
            if ((name.endswith(".bias") or name.endswith("_bias"))
                    and name not in params_dict):
                continue
            if is_pp_missing_parameter(name, self):
                continue
            name = maybe_remap_kv_scale_name(name, params_dict)
            if name is None:
                continue
            # Skip experts not assigned to this worker
            if ("feed_forward.experts." in name
                    and name not in params_dict):
                continue
            if name not in params_dict:
                logger.warning(
                    "Skipping weight %s not present in the model", name)
                continue
            param = params_dict[name]
            weight_loader = getattr(
                param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
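
When expert parallelism gives this rank a nonzero start_expert_id, checkpoint names carry global expert ids while the local parameters are numbered from 0; the regex block above shifts the id accordingly. A small standalone illustration of the rename (toy weight name, start_expert_id assumed to be 4):

    import re

    start_expert_id = 4
    name = "model.layers.1.feed_forward.experts.6.down_proj.weight"
    match = re.search(r'experts\.\d+', name)
    expert_id = int(match.group(0).split(".")[1])    # 6 (global id)
    name = name.replace(f"experts.{expert_id}",
                        f"experts.{expert_id - start_expert_id}")
    # name now contains "experts.2" (this rank's local id)
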
# ============================================================
# Apply all hijacks
# ============================================================

MluHijackObject.apply_hijack(
    Llama4Attention,
    Llama4Attention.__init__,
    vllm__llama4__Llama4Attention____init__)
MluHijackObject.apply_hijack(
    Llama4Attention,
    Llama4Attention.forward,
    vllm__llama4__Llama4Attention__forward)
MluHijackObject.apply_hijack(
    Llama4DecoderLayer,
    Llama4DecoderLayer.__init__,
    vllm__llama4__Llama4DecoderLayer____init__)
MluHijackObject.apply_hijack(
    Llama4DecoderLayer,
    Llama4DecoderLayer.forward,
    vllm__llama4__Llama4DecoderLayer__forward)
MluHijackObject.apply_hijack(
    Llama4Model,
    Llama4Model.forward,
    vllm__llama4__Llama4Model__forward)
MluHijackObject.apply_hijack(
    Llama4ForCausalLM,
    Llama4ForCausalLM.load_weights,
    vllm__llama4__Llama4ForCausalLM__load_weights)