Fix correction bias undefined behavior for nvfp4 models (#10426)
This commit is contained in:
@@ -65,6 +65,7 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
@@ -375,7 +376,8 @@ class DeepseekV2MoE(nn.Module):
         )
 
         correction_bias = self.gate.e_score_correction_bias
-        if _is_fp4_quantization_enabled():
+        # https://github.com/sgl-project/sglang/pull/9834#discussion_r2324480643
+        if _is_fp4_quantization_enabled() and should_use_flashinfer_trtllm_moe():
             correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
@@ -385,6 +385,8 @@ std::vector<at::Tensor> moe_fused_gate(
     int64_t num_fused_shared_experts,
     double routed_scaling_factor,
     bool apply_routed_scaling_factor_on_output) {
+  TORCH_CHECK(input.dtype() == bias.dtype(), "input and bias should have the same dtype");
+
   int64_t num_rows = input.size(0);
   int32_t num_experts = input.size(1);
   auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
Reference in New Issue
Block a user