Fix illegal memory in trtllm allreduce fusion (#7864)
This commit is contained in:
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
 if hidden_states.shape[0] != 0:
     hidden_states = layernorm(hidden_states)
 else:
+    # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+    # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
     if (
         _is_sm100_supported
         and _is_flashinfer_available
         and hasattr(layernorm, "forward_with_allreduce_fusion")
         and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-        and hidden_states.shape[0] <= 1024
+        and hidden_states.shape[0] <= 128
     ):
         hidden_states, residual = layernorm.forward_with_allreduce_fusion(
             hidden_states, residual
Reference in New Issue
Block a user