diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 4b4cdcb34..5e5801fff 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -655,7 +655,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
     # flashinfer uses this environment variable for various kernels from MoE to quant kernels
-    os.environ["TRTLLM_ENABLE_PDL"] = "1"
+    if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
+        os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 1a56e87c6..05b5490f8 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -299,7 +302,9 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(hidden_states, self.weight)
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
         elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
             logits = aiter_dsv3_router_gemm(
                 hidden_states, self.weight, gemm_output_zero_allocator
@@ -364,6 +369,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )
 
+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -371,7 +379,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=self.gate.e_score_correction_bias,
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,