[BugFix] Fix combination of MTP and --n-share-experts-fusion with R1 (#5707)
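What the fix does: DeepseekV3ForCausalLMNextN (the MTP / NextN draft model) now goes through the same shared-experts-fusion path as the main DeepSeek V3/R1 model. Its constructor calls determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN"), load_weights clones the shared expert of the single NextN layer into extra routed-expert slots, and the expert parameter mapping is built with num_experts=self.config.n_routed_experts + self.n_share_experts_fusion so the cloned experts are accounted for. The following sketch condenses the weight-cloning step into a standalone function; fuse_shared_experts and its signature are illustrative, not part of the sglang API.

# Minimal sketch (assumes checkpoint weights arrive as (name, tensor) pairs);
# it mirrors the cloning loop added to load_weights() in the diff below.
from typing import Dict, Iterable, List, Tuple

import torch


def fuse_shared_experts(
    weights: Iterable[Tuple[str, torch.Tensor]],
    n_routed_experts: int,
    n_share_experts_fusion: int,
    suffix_list: List[str],
) -> List[Tuple[str, torch.Tensor]]:
    """Append copies of the layer-0 shared expert as routed experts
    n_routed_experts .. n_routed_experts + n_share_experts_fusion - 1,
    then drop the original shared-expert entries."""
    weights_list = list(weights)
    weights_dict: Dict[str, torch.Tensor] = dict(weights_list)
    names_to_remove: List[str] = []
    for num_repeat in range(n_share_experts_fusion):
        for suffix in suffix_list:
            shared_name = f"model.layers.0.mlp.shared_experts.{suffix}"
            fused_name = (
                f"model.layers.0.mlp.experts."
                f"{n_routed_experts + num_repeat}.{suffix}"
            )
            weights_list.append((fused_name, weights_dict[shared_name]))
            names_to_remove.append(shared_name)
    return [w for w in weights_list if w[0] not in names_to_remove]

The real code additionally selects suffix_list from the quantization config (plain weight_scale names for no quantization or w8a8_int8, weight_scale_inv names otherwise), as the diff shows.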
@@ -13,12 +13,14 @@
# ==============================================================================

"""Inference-only DeepSeek NextN Speculative Decoding."""
import logging
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import PretrainedConfig

from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ReplicatedLinear
from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -51,6 +53,9 @@ else:
    from vllm._custom_ops import awq_dequantize


logger = logging.getLogger(__name__)


class DeepseekModelNextN(nn.Module):
    def __init__(
        self,
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
    ) -> None:
        nn.Module.__init__(self)
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")

        self.model = DeepseekModelNextN(
            config, quant_config, prefix=add_prefix("model", prefix)
@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        if self.n_share_experts_fusion > 0:
            logger.info(
                f"Cloning {self.n_share_experts_fusion} "
                "replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
            )
            weights_list = list(weights)
            weights_dict = dict(weights_list)
            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale",
                    "gate_proj.weight",
                    "gate_proj.weight_scale",
                    "up_proj.weight",
                    "up_proj.weight_scale",
                ]
            else:
                suffix_list = [
                    "down_proj.weight",
                    "down_proj.weight_scale_inv",
                    "gate_proj.weight",
                    "gate_proj.weight_scale_inv",
                    "up_proj.weight",
                    "up_proj.weight_scale_inv",
                ]
            names_to_remove = []
            for num_repeat in range(self.n_share_experts_fusion):
                for suffix in suffix_list:
                    shared_expert_weight_name = (
                        f"model.layers.0.mlp.shared_experts.{suffix}"
                    )
                    weights_list.append(
                        (
                            f"model.layers.0."
                            f"mlp.experts."
                            f"{self.config.n_routed_experts + num_repeat}"
                            f".{suffix}",
                            weights_dict[shared_expert_weight_name],
                        )
                    )
                    names_to_remove += [shared_expert_weight_name]
            weights = [w for w in weights_list if w[0] not in names_to_remove]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
@@ -190,7 +239,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts,
+            num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
        )

        nextn_layer_prefix = "model.layers.0"

@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
        self.config = config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.determine_n_share_experts_fusion()
        self.model = DeepseekV2Model(
            config, quant_config, prefix=add_prefix("model", prefix)
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=add_prefix("lm_head", prefix),
        )
        self.logits_processor = LogitsProcessor(config)
        self.dp_size = get_attention_dp_size()

    def determine_n_share_experts_fusion(
        self, architecture: str = "DeepseekV3ForCausalLM"
    ):
        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
        if self.n_share_experts_fusion > 0:
            # Only Deepseek V3/R1 can use shared experts fusion optimization now.
            if (
-                self.config.architectures[0] != "DeepseekV3ForCausalLM"
+                self.config.architectures[0] != architecture
                or self.config.n_routed_experts != 256
            ):
                self.n_share_experts_fusion = 0
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
        elif self.n_share_experts_fusion == 0:
            if (
                torch.cuda.get_device_capability("cuda") >= (9, 0)
-                and self.config.architectures[0] == "DeepseekV3ForCausalLM"
+                and self.config.architectures[0] == architecture
                and self.config.n_routed_experts == 256
                and (not global_server_args_dict["enable_deepep_moe"])
            ):
@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
            )

-        self.model = DeepseekV2Model(
-            config, quant_config, prefix=add_prefix("model", prefix)
-        )
-        self.lm_head = ParallelLMHead(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-            prefix=add_prefix("lm_head", prefix),
-        )
-        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens
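The DeepseekV2ForCausalLM side of the change parameterizes the architecture check: determine_n_share_experts_fusion now compares config.architectures[0] against an architecture argument instead of the hardcoded "DeepseekV3ForCausalLM", and the NextN subclass passes "DeepseekV3ForCausalLMNextN" so the draft model can take the fusion path too. A condensed sketch of the disable branch of that gate (the standalone function and its scalar arguments are illustrative, not the sglang method):

# Condensed sketch of the disable branch only; the real method also
# auto-enables fusion on SM90+ fp8 V3/R1 configs when the flag is 0.
def resolve_n_share_experts_fusion(
    requested: int,
    model_architecture: str,
    n_routed_experts: int,
    expected_architecture: str = "DeepseekV3ForCausalLM",
) -> int:
    """Keep the requested fusion count only for models whose architecture
    matches the expected name and that have 256 routed experts (V3/R1)."""
    if requested > 0 and (
        model_architecture != expected_architecture or n_routed_experts != 256
    ):
        return 0  # shared-experts fusion not applicable; disable it
    return requested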