[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)

This commit is contained in:
Xiaoyu Zhang
2025-07-03 10:36:20 +08:00
committed by GitHub
parent 82f021e22e
commit 8e64140e35
5 changed files with 253 additions and 2 deletions

View File

@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.utils import is_sm100_supported
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_cuda, is_flashinfer_available
# Capability flags computed once when this module is imported, so code that
# consults them later reads a cached boolean instead of re-probing.
# True when the FlashInfer package is reported as available.
_is_flashinfer_available = is_flashinfer_available()
# `and` short-circuits: is_sm100_supported() is only called when is_cuda()
# returns a truthy value.
_is_sm100_supported = is_cuda() and is_sm100_supported()
class ScatterMode(Enum):
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = layernorm(hidden_states, residual)
if (
_is_sm100_supported
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 1024
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual
@staticmethod