Revert "Optimized deepseek-v3/r1 model performance on mxfp4 run (#9671)" (#9959)

This commit is contained in:
Yineng Zhang
2025-09-03 00:50:04 -07:00
committed by GitHub
parent 2c7ca33abb
commit 1b2ff4fb7f
7 changed files with 59 additions and 455 deletions

View File

@@ -42,22 +42,10 @@ from sglang.srt.layers.moe import (
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import (
get_bool_env_var,
is_cuda,
is_flashinfer_available,
is_gfx95_supported,
is_hip,
is_sm100_supported,
)
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
_is_flashinfer_available = is_flashinfer_available()
_is_sm100_supported = is_cuda() and is_sm100_supported()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
_is_gfx95_supported = is_gfx95_supported()
if _use_aiter and _is_gfx95_supported:
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
@@ -213,7 +201,6 @@ class LayerCommunicator:
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
qaunt_format: str = "",
):
if hidden_states.shape[0] == 0:
residual = hidden_states
@@ -231,34 +218,11 @@ class LayerCommunicator:
else:
if residual is None:
residual = hidden_states
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
None,
)
else:
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.input_layernorm(hidden_states)
else:
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
hidden_states, residual = fused_rms_mxfp4_quant(
hidden_states,
self.input_layernorm.weight,
self.input_layernorm.variance_epsilon,
None,
None,
None,
residual,
)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
hidden_states, residual = self.input_layernorm(
hidden_states, residual
)
hidden_states = self._communicate_simple_fn(
hidden_states=hidden_states,