perf: Avoid unnecessary data type conversions for DeepSeek-V3 on Blackwell (#9834)
Signed-off-by: Jinyang Yuan <154768711+jinyangyuan-nvidia@users.noreply.github.com>
This commit is contained in:
@@ -655,7 +655,8 @@ def _set_envs_and_config(server_args: ServerArgs):
|
|||||||
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
||||||
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
||||||
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
# flashinfer uses this environment variable for various kernels from MoE to quant kernels
|
||||||
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
|
||||||
|
os.environ["TRTLLM_ENABLE_PDL"] = "1"
|
||||||
|
|
||||||
# Can also be passed as argument
|
# Can also be passed as argument
|
||||||
os.environ["SGLANG_RUN_ID"] = (
|
os.environ["SGLANG_RUN_ID"] = (
|
||||||
|
|||||||
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
|
|||||||
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
should_use_flashinfer_cutlass_moe_fp4_allgather,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
||||||
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||||
|
FusedMoE,
|
||||||
|
_is_fp4_quantization_enabled,
|
||||||
|
)
|
||||||
from sglang.srt.layers.moe.topk import TopK
|
from sglang.srt.layers.moe.topk import TopK
|
||||||
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -299,7 +302,9 @@ class MoEGate(nn.Module):
|
|||||||
and _device_sm >= 90
|
and _device_sm >= 90
|
||||||
):
|
):
|
||||||
# router gemm output float32
|
# router gemm output float32
|
||||||
logits = dsv3_router_gemm(hidden_states, self.weight)
|
logits = dsv3_router_gemm(
|
||||||
|
hidden_states, self.weight, out_dtype=torch.float32
|
||||||
|
)
|
||||||
elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
|
elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
|
||||||
logits = aiter_dsv3_router_gemm(
|
logits = aiter_dsv3_router_gemm(
|
||||||
hidden_states, self.weight, gemm_output_zero_allocator
|
hidden_states, self.weight, gemm_output_zero_allocator
|
||||||
@@ -364,6 +369,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
prefix=add_prefix("experts", prefix),
|
prefix=add_prefix("experts", prefix),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
correction_bias = self.gate.e_score_correction_bias
|
||||||
|
if _is_fp4_quantization_enabled():
|
||||||
|
correction_bias = correction_bias.to(torch.bfloat16)
|
||||||
self.topk = TopK(
|
self.topk = TopK(
|
||||||
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
||||||
renormalize=config.norm_topk_prob,
|
renormalize=config.norm_topk_prob,
|
||||||
@@ -371,7 +379,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
correction_bias=self.gate.e_score_correction_bias,
|
correction_bias=correction_bias,
|
||||||
routed_scaling_factor=self.routed_scaling_factor,
|
routed_scaling_factor=self.routed_scaling_factor,
|
||||||
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
|
||||||
force_topk=quant_config is None,
|
force_topk=quant_config is None,
|
||||||
|
|||||||
Reference in New Issue
Block a user