Fix illegal memory in trtllm allreduce fusion (#7864)

This commit is contained in:
Xiaoyu Zhang
2025-07-09 02:47:17 +08:00
committed by GitHub
parent 51ae40306a
commit 2e7ab862e3
3 changed files with 8 additions and 6 deletions

View File

@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
else:
+ # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+ # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
if (
_is_sm100_supported
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
- and hidden_states.shape[0] <= 1024
+ and hidden_states.shape[0] <= 128
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual