diff --git a/python/sglang/srt/models/deepseek_nextn.py b/python/sglang/srt/models/deepseek_nextn.py
index a01f386da..6c8ca0e0d 100644
--- a/python/sglang/srt/models/deepseek_nextn.py
+++ b/python/sglang/srt/models/deepseek_nextn.py
@@ -25,13 +25,18 @@ from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_r
 from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import Fp8Config
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
+from sglang.srt.models.deepseek_v2 import (
+    DeepseekV2DecoderLayer,
+    DeepseekV3ForCausalLM,
+    enable_nextn_moe_bf16_cast_to_fp8,
+)
 from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import BumpAllocator, add_prefix, is_cuda
 
@@ -49,6 +54,16 @@ class DeepseekModelNextN(nn.Module):
         prefix: str = "",
     ) -> None:
         super().__init__()
+
+        if enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+            # refer to real DeepSeek V3 quant config
+            moe_quant_config = Fp8Config(
+                is_checkpoint_fp8_serialized=True,
+                weight_block_size=[128, 128],
+            )
+        else:
+            moe_quant_config = None
+
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
                 "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
             )
@@ -74,6 +89,7 @@ class DeepseekModelNextN(nn.Module):
             config,
             0,
             quant_config=quant_config,
+            moe_quant_config=moe_quant_config,
             is_nextn=True,
             prefix=add_prefix("decoder", prefix),
             alt_stream=self.alt_stream,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index cdc1beb6f..cd7e8b682 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -26,6 +26,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
+from tqdm import tqdm, trange
 from transformers import PretrainedConfig
 
 from sglang.srt import single_batch_overlap
@@ -82,7 +83,7 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
-from sglang.srt.layers.quantization import deep_gemm_wrapper
+from sglang.srt.layers.quantization import Fp8Config, deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
@@ -196,6 +197,15 @@ _is_cublas_ge_129 = is_nvidia_cublas_cu12_version_ge_12_9()
 
 logger = logging.getLogger(__name__)
 
+
+def enable_nextn_moe_bf16_cast_to_fp8(quant_config):
+    return (
+        quant_config is not None
+        and quant_config.get_name() == "modelopt_fp4"
+        and get_moe_a2a_backend().is_deepep()
+    )
+
+
 FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
     "fa3",
     "nsa",
@@ -526,6 +536,7 @@ class DeepseekV2MoE(nn.Module):
         self.config = config
         self.layer_id = layer_id
         self.alt_stream = alt_stream
+        self.is_nextn = is_nextn
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -2381,6 +2392,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         config: PretrainedConfig,
         layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
+        moe_quant_config: Optional[QuantizationConfig] = None,
         is_nextn: bool = False,
         prefix: str = "",
         alt_stream: Optional[torch.cuda.Stream] = None,
@@ -2430,7 +2442,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         if self.is_layer_sparse:
             self.mlp = DeepseekV2MoE(
                 config=config,
-                quant_config=quant_config,
+                quant_config=moe_quant_config or quant_config,
                 prefix=add_prefix("mlp", prefix),
                 layer_id=self.layer_id,
                 alt_stream=alt_stream,
@@ -3109,6 +3121,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         ):
             self._weight_requant_ue8m0(is_nextn)
 
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            self._transform_scale_nextn_moe_ue8m0()
+
     def _weight_requant_ue8m0(self, is_nextn=False):
         weight_block_size = self.quant_config.weight_block_size
 
@@ -3174,6 +3189,28 @@ class DeepseekV2ForCausalLM(nn.Module):
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
 
+    # TODO avoid code dup (currently combine from weight_requant_ue8m0 and transform_scale_ue8m0)
+    def _transform_scale_nextn_moe_ue8m0(self):
+        layer = self.model.decoder
+
+        shared_experts = getattr(layer.mlp, "shared_experts", None)
+        if shared_experts is not None:
+            for module in [
+                shared_experts.gate_up_proj,
+                shared_experts.down_proj,
+            ]:
+                transform_scale_ue8m0_inplace(
+                    module.weight_scale_inv, mn=module.weight.shape[-2]
+                )
+
+        experts = layer.mlp.experts
+        if isinstance(experts, DeepEPMoE):
+            for w in [
+                experts.w13_weight_fp8,
+                experts.w2_weight_fp8,
+            ]:
+                transform_scale_ue8m0_inplace(w[1], mn=w[0].shape[-2])
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
 
         if is_nextn:
@@ -3189,6 +3226,11 @@ class DeepseekV2ForCausalLM(nn.Module):
             else:
                 raise ValueError("num_nextn_predict_layers is not in the config")
 
+        if is_nextn and enable_nextn_moe_bf16_cast_to_fp8(self.quant_config):
+            weights = self._quant_nextn_moe_to_fp8_ue8m0(
+                weights, nextn_layer_id=nextn_layer_id
+            )
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -3418,6 +3460,38 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
+    # TODO avoid code dup
+    def _quant_nextn_moe_to_fp8_ue8m0(self, weights, nextn_layer_id: int):
+        weights_dict = dict(weights)
+
+        # temporarily only support DeepSeek V3/R1
+        weight_block_size = [128, 128]
+
+        for layer_id in [nextn_layer_id]:
+            for expert_sub_name in [
+                "shared_experts",
+                *[
+                    f"experts.{expert_id}"
+                    for expert_id in range(self.config.n_routed_experts)
+                ],
+            ]:
+                for stem in [
+                    "gate_proj",
+                    "up_proj",
+                    "down_proj",
+                ]:
+                    partial_name = (
+                        f"model.layers.{layer_id}.mlp.{expert_sub_name}.{stem}"
+                    )
+                    original_weight = weights_dict[f"{partial_name}.weight"]
+                    out_w, out_s = quant_weight_ue8m0(
+                        original_weight, weight_block_size=weight_block_size
+                    )
+                    weights_dict[f"{partial_name}.weight"] = out_w
+                    weights_dict[f"{partial_name}.weight_scale_inv"] = out_s
+
+        return list(weights_dict.items())
+
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
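
Note on the bf16 -> FP8 cast above: for a modelopt_fp4 checkpoint running with the DeepEP all-to-all backend, the NextN (MTP) MoE weights arrive in bf16 and are re-quantized at load time to block-wise FP8 with UE8M0 (power-of-two) scales, mirroring the real DeepSeek V3 quant config ([128, 128] blocks). The snippet below is an illustrative, pure-PyTorch approximation of what quant_weight_ue8m0 does for a single 2-D weight; the helper name quant_weight_ue8m0_sketch, the rounding direction, and the dtype handling are assumptions for illustration, not the actual sglang kernel.

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn


def quant_weight_ue8m0_sketch(weight: torch.Tensor, weight_block_size=(128, 128)):
    """Hypothetical stand-in for quant_weight_ue8m0: block-wise FP8 cast with UE8M0 scales."""
    bn, bk = weight_block_size
    n, k = weight.shape
    assert n % bn == 0 and k % bk == 0, "sketch assumes shapes divisible by the block size"

    # group the weight into [bn, bk] blocks
    w = weight.float().reshape(n // bn, bn, k // bk, bk)
    # per-block absolute maximum
    amax = w.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    # UE8M0 scale: a pure power of two (exponent only, no mantissa bits)
    scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_MAX)))
    w_fp8 = (w / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

    out_w = w_fp8.reshape(n, k)              # FP8 weight
    out_s = scale.reshape(n // bn, k // bk)  # block dequant scales ("weight_scale_inv")
    return out_w, out_s


if __name__ == "__main__":
    w = torch.randn(256, 512, dtype=torch.bfloat16)
    out_w, out_s = quant_weight_ue8m0_sketch(w)
    print(out_w.dtype, out_w.shape, out_s.shape)  # torch.float8_e4m3fn, (256, 512), (2, 4)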