Upgrade to vllm 0.17.0 corex v4.1 overlay

2026-04-29 19:38:22 +08:00
parent 8fac6062e4
commit 938d0854a5
430 changed files with 35969 additions and 14511 deletions


@@ -47,7 +47,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import GateLinear, SharedFusedMoE
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -75,7 +75,9 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
)
@@ -89,6 +91,7 @@ from .utils import (
make_layers,
maybe_prefix,
)
import ixformer.inference.functions as ixfops
logger = init_logger(__name__)
@@ -221,73 +224,6 @@ class DeepseekV2MLP(nn.Module):
return x
class DeepSeekV2Gate(ReplicatedLinear):
def __init__(
self,
hidden_size: int,
n_experts: int,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
assert quant_config is None
super().__init__(
hidden_size,
n_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate",
)
# Unquantized only, will be called "weight".
assert hasattr(self, "weight")
is_hopper_or_blackwell = current_platform.is_device_capability(
(9, 0)
) or current_platform.is_device_capability_family(100)
SUPPORTED_NUM_EXPERTS = [256, 384]
SUPPORTED_HIDDEN_SIZES = [7168]
self.allow_dsv3_router_gemm = (
current_platform.is_cuda()
and is_hopper_or_blackwell
and n_experts in SUPPORTED_NUM_EXPERTS
and hidden_size in SUPPORTED_HIDDEN_SIZES
)
self._out_dtype: torch.dtype | None = None
def set_out_dtype(self, out_dtype: torch.dtype) -> None:
"""
Set out dtype for the router logits. This is needed after
__init__, b/c we need to check if the trtllm kernel is
selected before we decide between bf16 and fp32.
"""
if self._out_dtype is not None:
raise ValueError("out_dtype has already been set")
else:
self._out_dtype = out_dtype
@property
def out_dtype(self) -> torch.dtype:
if self._out_dtype is None:
raise ValueError("out_dtype has not been set yet")
return self._out_dtype
def forward(
self,
x: torch.Tensor,
) -> tuple[torch.Tensor, None]:
"""
Use specialized GEMM for low batch size for DSV3 and KIMI.
"""
if self.allow_dsv3_router_gemm and x.shape[0] <= 16:
return ops.dsv3_router_gemm(
hidden_states=x, router_weight=self.weight, output_dtype=self.out_dtype
), None
else:
return super().forward(x)
class DeepseekV2MoE(nn.Module):
def __init__(
self,
@@ -316,23 +252,12 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
# self.gate = DeepSeekV2Gate(
# config.hidden_size,
# config.n_routed_experts,
# quant_config=None,
# prefix=f"{prefix}.gate",
# )
self.gate = ReplicatedLinear(
self.gate = GateLinear(
config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)
if getattr(config, "topk_method", None) == "noaux_tc":
# self.gate.e_score_correction_bias = nn.Parameter(
# torch.empty(config.n_routed_experts, dtype=torch.float32)
# )
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts)
)
@@ -401,12 +326,12 @@ class DeepseekV2MoE(nn.Module):
else None,
)
# # NOTE(rob): this is a hack until we finish off the PR for
# # merging TRTLLM kernels into the MK framework. Then we can
# # query the MonolithicMK for the expected router logits.
# self.gate.set_out_dtype(
# torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
# )
# NOTE(rob): this is a hack until we finish off the PR for
# merging TRTLLM kernels into the MK framework. Then we can
# query the MonolithicMK for the expected router logits.
self.gate.set_out_dtype(
torch.float32 if self.experts.quant_method.is_monolithic else torch.bfloat16
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
@@ -443,11 +368,12 @@ class DeepseekV2MoE(nn.Module):
elif self.shared_experts is not None:
assert shared_output is not None
shared_output *= 1.0 / self.routed_scaling_factor
if self.shared_experts is not None:
assert shared_output is not None
final_hidden_states += shared_output
if self.is_sequence_parallel:
final_hidden_states = tensor_model_parallel_all_gather(
final_hidden_states, 0
@@ -596,7 +522,7 @@ class DeepseekV2Attention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
@@ -605,23 +531,20 @@ class DeepseekV2Attention(nn.Module):
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
    [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
)
q = self.q_a_layernorm(q)
q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
else:
q = self.q_proj(hidden_states)[0].view(
-1, self.num_local_heads, self.qk_head_dim
)
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
latent_cache = latent_cache.unsqueeze(1)
q = self.q_proj(hidden_states)[0]
kv_a, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
    [self.kv_lora_rank, self.qk_rope_head_dim], dim=1
)
q = q.view(-1, self.num_local_heads, self.qk_head_dim)
_, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
kv_a = self.kv_a_layernorm(kv_a)
kv = self.kv_b_proj(kv_a)[0]
kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k_pe = latent_cache[:, :, self.kv_lora_rank :]
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
k_nope, v_nope = kv.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q[..., self.qk_nope_head_dim :] = q_pe
k = torch.empty_like(q)
@@ -671,7 +594,7 @@ class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase):
def get_attn_backend(self) -> AttentionBackend:
return DeepseekV32IndexerBackend
class Indexer(nn.Module):
def __init__(
@@ -727,8 +650,8 @@ class Indexer(nn.Module):
# where we store value in fp8 and scale in fp32
# per self.quant_block_size element
self.k_cache = DeepseekV32IndexerCache(
head_dim=self.head_dim + self.head_dim // self.quant_block_size * 4,
dtype=torch.uint8,
head_dim=self.head_dim,
dtype=torch.bfloat16,
prefix=f"{prefix}.k_cache",
cache_config=cache_config,
)
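# Worked sizing example (editorial sketch, not part of this diff), assuming the
# indexer uses head_dim = 128 and quant_block_size = 128:
#   fp8 layout : 128 fp8 bytes + (128 // 128) * 4 bytes of fp32 scale = 132 bytes per row,
#                which is exactly head_dim + head_dim // self.quant_block_size * 4 above
#   bf16 layout: head_dim stays 128 (2 bytes per element), no packed scales needed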
@@ -776,23 +699,61 @@ class Indexer(nn.Module):
k = torch.cat([k_pe.squeeze(-2), k_nope], dim=-1)
# we only quant q here since k quant is fused with cache insertion
q = q.view(-1, self.head_dim)
q_fp8, q_scale = per_token_group_quant_fp8(
q,
self.quant_block_size,
column_major_scales=False,
use_ue8m0=self.scale_fmt is not None,
)
q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
q_scale = q_scale.view(-1, self.n_head, 1)
# q = q.view(-1, self.head_dim)
# q_fp8, q_scale = per_token_group_quant_fp8(
# q,
# self.quant_block_size,
# column_major_scales=False,
# use_ue8m0=self.scale_fmt is not None,
# )
# q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim)
# q_scale = q_scale.view(-1, self.n_head, 1)
weights, _ = self.weights_proj(hidden_states)
weights = (
weights.unsqueeze(-1) * q_scale * self.softmax_scale * self.n_head**-0.5
weights.unsqueeze(-1) * self.softmax_scale * self.n_head**-0.5
)
weights = weights.squeeze(-1)
return self.indexer_op(hidden_states, q_fp8, k, weights)
return self.indexer_op(hidden_states, q, k, weights)
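# Editorial note (not in this diff): because q is no longer quantized to fp8
# above, there is no per-token q_scale to fold into the routing weights; they
# are scaled only by softmax_scale and n_head**-0.5, and the bf16 q is passed
# straight to indexer_op.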
def _min_latency_fused_qkv_a_proj_impl(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
"""
Dynamically run min-latency gemm if num_tokens <= 16.
This must be wrapped in a custom op because our torch.compile integration
does not support runtime dispatching on num_tokens.
"""
num_tokens = input_.shape[0]
if 0 < num_tokens <= 16:
output = torch.empty(
num_tokens,
weight.shape[0],
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(output, input_, weight.T)
return output
else:
return torch.nn.functional.linear(input_, weight)
def _min_latency_fused_qkv_a_proj_fake(
input_: torch.Tensor,
weight: torch.Tensor,
) -> torch.Tensor:
return input_.new_empty(input_.shape[0], weight.shape[0])
direct_register_custom_op(
op_name="min_latency_fused_qkv_a_proj",
op_func=_min_latency_fused_qkv_a_proj_impl,
mutates_args=[],
fake_impl=_min_latency_fused_qkv_a_proj_fake,
)
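# Usage sketch (illustrative, not part of this diff): the op registered above
# lives under the torch.ops.vllm namespace, so torch.compile traces one opaque
# call whose output shape comes from the fake impl, while the eager impl picks
# the min-latency GEMM or a plain linear at runtime based on num_tokens.
#
#   hidden = torch.randn(8, 7168, dtype=torch.bfloat16, device="cuda")  # assumed sizes
#   fused = torch.ops.vllm.min_latency_fused_qkv_a_proj(hidden, qkv_a_weight)  # weight: [2112, 7168]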
class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
@@ -830,19 +791,8 @@ class DeepSeekV2FusedQkvAProj(MergedColumnParallelLinear):
self,
input_,
) -> torch.Tensor | tuple[torch.Tensor, torch.nn.Parameter | None]:
num_tokens = input_.shape[0]
if self._use_min_latency_gemm and (0 < num_tokens <= 16):
output = torch.empty(
num_tokens,
2112,
dtype=torch.bfloat16,
device=input_.device,
)
ops.dsv3_fused_a_gemm(
output,
input_,
self.weight.T,
)
if self._use_min_latency_gemm:
output = torch.ops.vllm.min_latency_fused_qkv_a_proj(input_, self.weight)
if not self.return_bias:
return output
output_bias = self.bias if self.skip_bias_add else None
@@ -898,47 +848,35 @@ class DeepseekV2MLAAttention(nn.Module):
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
# self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProj(
# self.hidden_size,
# [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
# quant_config=quant_config,
# prefix=f"{prefix}.fused_qkv_a_proj",
# )
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True,
)
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(self.q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
)
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
self.q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -1005,9 +943,7 @@ class DeepseekV2MLAAttention(nn.Module):
kv_b_proj=self.kv_b_proj,
rotary_emb=self.rotary_emb,
o_proj=self.o_proj,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None
else None,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa
if self.q_lora_rank is None
else None,
@@ -1346,14 +1282,14 @@ class DeepseekV2ForCausalLM(
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = (
hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
)
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
# self.fuse_qkv_a_proj = (
# hasattr(config, "q_lora_rank") and config.q_lora_rank is not None
# )
# if self.fuse_qkv_a_proj:
# self.packed_modules_mapping["fused_qkv_a_proj"] = [
# "q_a_proj",
# "kv_a_proj_with_mqa",
# ]
self.model = self.model_cls(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
@@ -1385,19 +1321,19 @@ class DeepseekV2ForCausalLM(
self.moe_layers = []
self.moe_mlp_layers = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
# Pick last one layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_mlp_layers.append(layer.mlp)
self.moe_layers.append(layer.mlp.experts)
self.extract_moe_parameters(example_moe)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
@@ -1441,10 +1377,10 @@ class DeepseekV2ForCausalLM(
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
mla_params_mapping = [
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# mla_params_mapping = [
# ("fused_qkv_a_proj", "q_a_proj", 0),
# ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
# ]
mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
@@ -1452,8 +1388,8 @@ class DeepseekV2ForCausalLM(
]
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
else:
stacked_params_mapping.extend(mla_params_mapping)
# else:
# stacked_params_mapping.extend(mla_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -1474,168 +1410,232 @@ class DeepseekV2ForCausalLM(
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if is_fusion_moe_shared_experts_layer:
continue
name_mapped = name.replace(weight_name, param_name)
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if (
param_name == "fused_qkv_a_proj"
) and name_mapped not in params_dict:
continue
else:
name = name_mapped
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
try:
if "rotary_emb.inv_freq" in name:
continue
if is_pp_missing_parameter(name, self):
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
# (e.g. ...mlp.shared_experts.gate_proj.weight).
# For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
)
chunk_size = total // num_chunks
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
is_fusion_moe_shared_experts_layer = (
rocm_aiter_moe_shared_expert_enabled and ("mlp.shared_experts" in name)
)
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if ("mlp.experts." in name) and name not in params_dict:
continue
if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
continue
name_mapped = name.replace(weight_name, param_name)
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
# QKV fusion is optional, fall back to normal
# weight loading if it's not enabled
# if go with fusion option, then update name
if (
param_name == "fused_qkv_a_proj"
) and name_mapped not in params_dict:
continue
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
name = name_mapped
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
is_expert_weight = False
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
# Special handling: when AITER fusion_shared_experts is enabled,
# checkpoints may provide a single widened shared_experts tensor
# without explicit expert indices
# (e.g. ...mlp.shared_experts.gate_proj.weight).
# For models with multiple shared experts, split that tensor
# evenly into per-shared-expert slices and load them into
# appended expert slots mlp.experts.{n_routed_experts + j}.*
# accordingly.
num_chunks = 1
if is_fusion_moe_shared_experts_layer:
num_chunks = getattr(self.config, "n_shared_experts", 1) or 1
# Determine split axis based on op type
# gate/up: ColumnParallel → split along dim 0
# down: RowParallel → split along dim 1
split_dim = (
1
if ("down_proj.weight" in name and loaded_weight.ndim > 1)
else 0
)
weight_loader(param, loaded_weight)
if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
total = loaded_weight.shape[split_dim]
assert total % num_chunks == 0, (
f"Shared expert weight dim {total} "
f"not divisible by num_chunks {num_chunks}"
)
chunk_size = total // num_chunks
for j in range(num_chunks):
chunk_name = name
weight_to_load = loaded_weight
if is_fusion_moe_shared_experts_layer:
chunk_slice = slice(j * chunk_size, (j + 1) * chunk_size)
if loaded_weight.ndim == 1:
weight_to_load = loaded_weight[chunk_slice]
elif split_dim == 0:
weight_to_load = loaded_weight[chunk_slice, :]
else:
weight_to_load = loaded_weight[:, chunk_slice]
# Synthesize an expert-style name so expert mapping
# can route it
chunk_name = name.replace(
"mlp.shared_experts",
f"mlp.experts.{self.config.n_routed_experts + j}",
)
# Use expert_params_mapping to locate the destination
# param and delegate to its expert-aware weight_loader
# with expert_id.
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in chunk_name:
continue
# Anyway, this is an expert weight and should not be
# attempted to load as other weights later
is_expert_weight = True
# Do not modify `name` since the loop may continue here
# Instead, create a new variable
name_mapped = chunk_name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or
# not here since otherwise we may skip experts with
# other available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
weight_to_load,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
if not is_fusion_moe_shared_experts_layer:
name = name_mapped
else:
loaded_params.add(name_mapped)
break
else:
if is_expert_weight:
# We've checked that this is an expert weight
# However it's not mapped locally to this rank
# So we simply skip it
continue
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if name is not None and not is_fusion_moe_shared_experts_layer:
loaded_params.add(name)
except Exception:
    logger.warning("Failed to load weight %s; skipping.", name)
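# Worked example (editorial, not part of this diff) for the shared-experts
# splitting handled above: assuming n_shared_experts = 2 and
# moe_intermediate_size = 1536, a widened checkpoint tensor
# mlp.shared_experts.gate_proj.weight of shape [2 * 1536, hidden_size] is
# sliced along dim 0 into two [1536, hidden_size] chunks and loaded as
# mlp.experts.{n_routed_experts} and mlp.experts.{n_routed_experts + 1}.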
opt_support_quant_method = [
    "GGUFLinearMethod",
    "UnquantizedLinearMethod",
    "CompressedTensorsW8A8Int8",
    "AWQMarlinLinearMethod",
]
# Add additional supported optimized quant methods here.
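# inject_layer below concatenates the q-side projection weight (q_a_proj when
# q_lora_rank is set, q_proj otherwise for non-MLA attention) with the
# kv_a_proj_with_mqa weight so a single GEMM can produce both outputs, then
# drops the standalone kv_a_proj_with_mqa module and switches the layer to its
# optimized forward path (forward_opt, assumed to be provided elsewhere in this
# overlay).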
def inject_layer(layer, quant_method, is_mla):
q_lora_rank = getattr(layer, "q_lora_rank", None)
if quant_method in ["UnquantizedLinearMethod", "CompressedTensorsW8A8Int8"]:
if q_lora_rank is not None:
    layer.q_a_proj.weight.data = torch.cat(
        [layer.q_a_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0
    )
    if hasattr(layer.q_a_proj, "weight_scale"):
        layer.q_a_proj.weight_scale.data = torch.cat(
            [layer.q_a_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale],
            dim=0,
        )
        del layer.kv_a_proj_with_mqa.weight_scale
elif not is_mla:
    layer.q_proj.weight.data = torch.cat(
        [layer.q_proj.weight, layer.kv_a_proj_with_mqa.weight], dim=0
    )
    if hasattr(layer.q_proj, "weight_scale"):
        layer.q_proj.weight_scale.data = torch.cat(
            [layer.q_proj.weight_scale, layer.kv_a_proj_with_mqa.weight_scale],
            dim=0,
        )
        del layer.kv_a_proj_with_mqa.weight_scale
else:
return
del layer.kv_a_proj_with_mqa.weight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
elif quant_method == "GGUFLinearMethod":
pass
elif quant_method == "AWQMarlinLinearMethod":
dtype = layer.kv_a_proj_with_mqa.qweight.dtype
assert dtype == torch.int32
if q_lora_rank is not None:
    layer.q_a_proj.qweight.data = torch.cat(
        [layer.q_a_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1
    )
    layer.q_a_proj.scales.data = torch.cat(
        [layer.q_a_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1
    )
    del layer.kv_a_proj_with_mqa.scales
    layer.q_a_proj.qzeros.data = torch.cat(
        [layer.q_a_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1
    )
    del layer.kv_a_proj_with_mqa.qzeros
elif not is_mla:
    layer.q_proj.qweight.data = torch.cat(
        [layer.q_proj.qweight, layer.kv_a_proj_with_mqa.qweight], dim=1
    )
    layer.q_proj.scales.data = torch.cat(
        [layer.q_proj.scales, layer.kv_a_proj_with_mqa.scales], dim=1
    )
    del layer.kv_a_proj_with_mqa.scales
    layer.q_proj.qzeros.data = torch.cat(
        [layer.q_proj.qzeros, layer.kv_a_proj_with_mqa.qzeros], dim=1
    )
    del layer.kv_a_proj_with_mqa.qzeros
else:
return
del layer.kv_a_proj_with_mqa.qweight
del layer.kv_a_proj_with_mqa
if is_mla:
layer.mla_attn.forward = layer.mla_attn.forward_opt
else:
layer.forward = layer.forward_opt
else:
pass
for _, layer in self.model.named_modules():
    if layer.__class__.__name__ in ["DeepseekV2Attention", "DeepseekV2MLAAttention"]:
        if hasattr(layer.kv_a_proj_with_mqa, "scheme"):
            quant_method = layer.kv_a_proj_with_mqa.scheme.__class__.__name__
        else:
            quant_method = layer.kv_a_proj_with_mqa.quant_method.__class__.__name__
        if quant_method not in opt_support_quant_method:
            break
        inject_layer(
            layer,
            quant_method,
            is_mla=layer.__class__.__name__ == "DeepseekV2MLAAttention",
        )
return loaded_params
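# Editorial sketch (not part of this diff): after inject_layer has fused the
# weights, the optimized forward path is expected to run one projection and
# split it back into the q and kv_a/rope parts. forward_opt is assumed to look
# roughly like this; module and attribute names mirror DeepseekV2MLAAttention.
#
#   fused = self.q_a_proj(hidden_states)[0]
#   q, kv_a, k_pe = fused.split(
#       [self.q_lora_rank, self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
#   )
#   q = self.q_b_proj(self.q_a_layernorm(q))[0]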