Optimized deepseek-v3/r1 model performance on mxfp4 run (#9671)

Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: wghuang <wghuang@amd.com>
This commit is contained in:
kk
2025-09-03 13:26:28 +08:00
committed by GitHub
parent bcbeed714f
commit 0dfd54d11d
7 changed files with 455 additions and 59 deletions

View File

@@ -42,10 +42,22 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_flashinfer_available,
is_gfx95_supported,
is_hip,
is_sm100_supported,
)
# Capability flags, evaluated once when this module is imported.
_is_flashinfer_available = is_flashinfer_available()
# Truthy only when is_cuda() holds and is_sm100_supported() also reports support.
_is_sm100_supported = is_cuda() and is_sm100_supported()
# Needs both the SGLANG_USE_AITER setting (read via get_bool_env_var) and is_hip().
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()
# fused_rms_mxfp4_quant is imported only when both flags above are set, so every
# use of that name must sit behind the same `_use_aiter and _is_gfx95_supported` check.
# NOTE(review): this view shows no indentation; the import is presumably the
# body of the `if` below -- confirm against the real file.
if _use_aiter and _is_gfx95_supported:
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
# NOTE(review): presumably a batch-size ceiling for the fused all-reduce path;
# its use is not visible in this chunk -- confirm.
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -201,6 +213,7 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
qaunt_format: str = "",
):
if hidden_states.shape[0] == 0:
residual = hidden_states
@@ -218,11 +231,34 @@ class LayerCommunicator:
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
None,
)
else:
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states, residual = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
residual,
)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
hidden_states = self._communicate_simple_fn(
hidden_states=hidden_states,