Optimized deepseek-v3/r1 model performance on mxfp4 run (#10008)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: HAI <hixiao@gmail.com>
Co-authored-by: Hubert Lu <55214931+hubertlu-tw@users.noreply.github.com>
This commit is contained in:
kk
2025-09-05 06:11:22 +08:00
committed by GitHub
parent 93088b6975
commit e96973742c
8 changed files with 486 additions and 64 deletions

View File

@@ -43,8 +43,11 @@ from sglang.srt.layers.moe import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_flashinfer_available,
is_gfx95_supported,
is_hip,
is_sm90_supported,
is_sm100_supported,
)
@@ -52,6 +55,11 @@ from sglang.srt.utils import (
# Platform/feature flags, evaluated once when this module is imported and
# reused by the code below instead of re-querying on every forward pass.
_is_flashinfer_available = is_flashinfer_available()
# The `is_cuda()` check comes first, so the SM-architecture probes are only
# called when `is_cuda()` is truthy.
_is_sm90_supported = is_cuda() and is_sm90_supported()
_is_sm100_supported = is_cuda() and is_sm100_supported()
# The aiter path is opt-in: the SGLANG_USE_AITER environment variable must be
# set to a true value AND `is_hip()` must report true.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()
# `fused_rms_mxfp4_quant` is imported only when both flags above are true, so
# the name is undefined otherwise. Every call site must be guarded by the same
# `_use_aiter and _is_gfx95_supported` condition.
if _use_aiter and _is_gfx95_supported:
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
# NOTE(review): presumably the largest batch size for which the fused
# all-reduce path is used — the consumer is not visible in this hunk; confirm.
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -207,6 +215,7 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
qaunt_format: str = "",
):
if hidden_states.shape[0] == 0:
residual = hidden_states
@@ -224,11 +233,34 @@ class LayerCommunicator:
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
None,
)
else:
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states, residual = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
residual,
)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
hidden_states = self._communicate_simple_fn(
hidden_states=hidden_states,