This commit is contained in:
@@ -42,22 +42,10 @@ from sglang.srt.layers.moe import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
is_cuda,
|
||||
is_flashinfer_available,
|
||||
is_gfx95_supported,
|
||||
is_hip,
|
||||
is_sm100_supported,
|
||||
)
|
||||
from sglang.srt.utils import is_cuda, is_flashinfer_available, is_sm100_supported
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
|
||||
_is_gfx95_supported = is_gfx95_supported()
|
||||
|
||||
if _use_aiter and _is_gfx95_supported:
|
||||
from sglang.srt.layers.quantization.rocm_mxfp4_utils import fused_rms_mxfp4_quant
|
||||
|
||||
FUSE_ALLREDUCE_MAX_BATCH_SIZE = 2048
|
||||
|
||||
@@ -213,7 +201,6 @@ class LayerCommunicator:
|
||||
hidden_states: torch.Tensor,
|
||||
residual: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
qaunt_format: str = "",
|
||||
):
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
@@ -231,34 +218,11 @@ class LayerCommunicator:
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
|
||||
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
|
||||
hidden_states = fused_rms_mxfp4_quant(
|
||||
hidden_states,
|
||||
self.input_layernorm.weight,
|
||||
self.input_layernorm.variance_epsilon,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
else:
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
if _use_aiter and _is_gfx95_supported and ("mxfp4" in qaunt_format):
|
||||
hidden_states, residual = fused_rms_mxfp4_quant(
|
||||
hidden_states,
|
||||
self.input_layernorm.weight,
|
||||
self.input_layernorm.variance_epsilon,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
residual,
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
hidden_states = self._communicate_simple_fn(
|
||||
hidden_states=hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user