[b200] support trt-llm allreduce fuse rms_norm_add kernel (#7621)
This commit is contained in:
@@ -32,8 +32,13 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
)
|
||||
from sglang.srt.layers.utils import is_sm100_supported
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.utils import is_cuda, is_flashinfer_available
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
|
||||
|
||||
class ScatterMode(Enum):
|
||||
@@ -397,8 +402,19 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
if (
|
||||
_is_sm100_supported
|
||||
and _is_flashinfer_available
|
||||
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
||||
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
|
||||
and hidden_states.shape[0] <= 1024
|
||||
):
|
||||
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
||||
hidden_states, residual
|
||||
)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
hidden_states, residual = layernorm(hidden_states, residual)
|
||||
return hidden_states, residual
|
||||
|
||||
@staticmethod
|
||||
|
||||
Reference in New Issue
Block a user