Fix illegal memory in trtllm allreduce fusion (#7864)

This commit is contained in:
Xiaoyu Zhang
2025-07-09 02:47:17 +08:00
committed by GitHub
parent 51ae40306a
commit 2e7ab862e3
3 changed files with 8 additions and 6 deletions

View File

@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
if hidden_states.shape[0] != 0:
hidden_states = layernorm(hidden_states)
else:
# According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
# We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
if (
_is_sm100_supported
and _is_flashinfer_available
and hasattr(layernorm, "forward_with_allreduce_fusion")
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
and hidden_states.shape[0] <= 1024
and hidden_states.shape[0] <= 128
):
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
hidden_states, residual

View File

@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
def ensure_workspace_initialized(
max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
):
"""Ensure workspace is initialized"""
if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
return _workspace_manager.initialized
def flashinfer_allreduce_add_rmsnorm(
def flashinfer_allreduce_residual_rmsnorm(
input_tensor: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
max_token_num: int = 1024,
max_token_num: int = 128,
use_oneshot: bool = True,
trigger_completion_at_end: bool = False,
fp32_acc: bool = False,

View File

@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
if residual is not None:
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.layers.flashinfer_comm_fusion import (
flashinfer_allreduce_add_rmsnorm,
flashinfer_allreduce_residual_rmsnorm,
)
if get_tensor_model_parallel_world_size() > 1:
fused_result = flashinfer_allreduce_add_rmsnorm(
fused_result = flashinfer_allreduce_residual_rmsnorm(
input_tensor=x,
residual=residual,
weight=self.weight,