[BugFix] Fix combination of MTP and --n-share-experts-fusion with R1 (#5707)
This commit is contained in:
@@ -13,12 +13,14 @@
|
|||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|
||||||
"""Inference-only DeepSeek NextN Speculative Decoding."""
|
"""Inference-only DeepSeek NextN Speculative Decoding."""
|
||||||
|
import logging
|
||||||
from typing import Iterable, Optional, Tuple
|
from typing import Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import ReplicatedLinear
|
from sglang.srt.layers.linear import ReplicatedLinear
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
@@ -51,6 +53,9 @@ else:
|
|||||||
from vllm._custom_ops import awq_dequantize
|
from vllm._custom_ops import awq_dequantize
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekModelNextN(nn.Module):
|
class DeepseekModelNextN(nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -134,7 +139,9 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
) -> None:
|
) -> None:
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.determine_n_share_experts_fusion("DeepseekV3ForCausalLMNextN")
|
||||||
|
|
||||||
self.model = DeepseekModelNextN(
|
self.model = DeepseekModelNextN(
|
||||||
config, quant_config, prefix=add_prefix("model", prefix)
|
config, quant_config, prefix=add_prefix("model", prefix)
|
||||||
@@ -182,6 +189,48 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
("gate_up_proj", "gate_proj", 0),
|
("gate_up_proj", "gate_proj", 0),
|
||||||
("gate_up_proj", "up_proj", 1),
|
("gate_up_proj", "up_proj", 1),
|
||||||
]
|
]
|
||||||
|
if self.n_share_experts_fusion > 0:
|
||||||
|
logger.info(
|
||||||
|
f"Cloning {self.n_share_experts_fusion} "
|
||||||
|
"replicas of the shared expert into MoE for DeepseekV3ForCausalLMNextN"
|
||||||
|
)
|
||||||
|
weights_list = list(weights)
|
||||||
|
weights_dict = dict(weights_list)
|
||||||
|
if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
suffix_list = [
|
||||||
|
"down_proj.weight",
|
||||||
|
"down_proj.weight_scale_inv",
|
||||||
|
"gate_proj.weight",
|
||||||
|
"gate_proj.weight_scale_inv",
|
||||||
|
"up_proj.weight",
|
||||||
|
"up_proj.weight_scale_inv",
|
||||||
|
]
|
||||||
|
names_to_remove = []
|
||||||
|
for num_repeat in range(self.n_share_experts_fusion):
|
||||||
|
for suffix in suffix_list:
|
||||||
|
shared_expert_weight_name = (
|
||||||
|
f"model.layers.0.mlp.shared_experts.{suffix}"
|
||||||
|
)
|
||||||
|
weights_list.append(
|
||||||
|
(
|
||||||
|
f"model.layers.0."
|
||||||
|
f"mlp.experts."
|
||||||
|
f"{self.config.n_routed_experts + num_repeat}"
|
||||||
|
f".{suffix}",
|
||||||
|
weights_dict[shared_expert_weight_name],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
names_to_remove += [shared_expert_weight_name]
|
||||||
|
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
||||||
|
|
||||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||||
# (param_name, weight_name, expert_id, shard_id)
|
# (param_name, weight_name, expert_id, shard_id)
|
||||||
@@ -190,7 +239,7 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
ckpt_gate_proj_name="gate_proj",
|
ckpt_gate_proj_name="gate_proj",
|
||||||
ckpt_down_proj_name="down_proj",
|
ckpt_down_proj_name="down_proj",
|
||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
num_experts=self.config.n_routed_experts,
|
num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
|
||||||
)
|
)
|
||||||
|
|
||||||
nextn_layer_prefix = "model.layers.0"
|
nextn_layer_prefix = "model.layers.0"
|
||||||
|
|||||||
@@ -1440,11 +1440,27 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
self.determine_n_share_experts_fusion()
|
||||||
|
self.model = DeepseekV2Model(
|
||||||
|
config, quant_config, prefix=add_prefix("model", prefix)
|
||||||
|
)
|
||||||
|
self.lm_head = ParallelLMHead(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=add_prefix("lm_head", prefix),
|
||||||
|
)
|
||||||
|
self.logits_processor = LogitsProcessor(config)
|
||||||
|
self.dp_size = get_attention_dp_size()
|
||||||
|
|
||||||
|
def determine_n_share_experts_fusion(
|
||||||
|
self, architecture: str = "DeepseekV3ForCausalLM"
|
||||||
|
):
|
||||||
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
|
||||||
if self.n_share_experts_fusion > 0:
|
if self.n_share_experts_fusion > 0:
|
||||||
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
||||||
if (
|
if (
|
||||||
self.config.architectures[0] != "DeepseekV3ForCausalLM"
|
self.config.architectures[0] != architecture
|
||||||
or self.config.n_routed_experts != 256
|
or self.config.n_routed_experts != 256
|
||||||
):
|
):
|
||||||
self.n_share_experts_fusion = 0
|
self.n_share_experts_fusion = 0
|
||||||
@@ -1459,7 +1475,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
elif self.n_share_experts_fusion == 0:
|
elif self.n_share_experts_fusion == 0:
|
||||||
if (
|
if (
|
||||||
torch.cuda.get_device_capability("cuda") >= (9, 0)
|
torch.cuda.get_device_capability("cuda") >= (9, 0)
|
||||||
and self.config.architectures[0] == "DeepseekV3ForCausalLM"
|
and self.config.architectures[0] == architecture
|
||||||
and self.config.n_routed_experts == 256
|
and self.config.n_routed_experts == 256
|
||||||
and (not global_server_args_dict["enable_deepep_moe"])
|
and (not global_server_args_dict["enable_deepep_moe"])
|
||||||
):
|
):
|
||||||
@@ -1469,18 +1485,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
|
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model = DeepseekV2Model(
|
|
||||||
config, quant_config, prefix=add_prefix("model", prefix)
|
|
||||||
)
|
|
||||||
self.lm_head = ParallelLMHead(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("lm_head", prefix),
|
|
||||||
)
|
|
||||||
self.logits_processor = LogitsProcessor(config)
|
|
||||||
self.dp_size = get_attention_dp_size()
|
|
||||||
|
|
||||||
def get_input_embeddings(self) -> nn.Embedding:
|
def get_input_embeddings(self) -> nn.Embedding:
|
||||||
return self.model.embed_tokens
|
return self.model.embed_tokens
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user