Fix illegal memory access in trtllm allreduce fusion (#7864)
@@ -402,12 +402,14 @@ class CommunicateWithAllReduceAndLayerNormFn:
             if hidden_states.shape[0] != 0:
                 hidden_states = layernorm(hidden_states)
         else:
+            # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
+            # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
             if (
                 _is_sm100_supported
                 and _is_flashinfer_available
                 and hasattr(layernorm, "forward_with_allreduce_fusion")
                 and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-                and hidden_states.shape[0] <= 1024
+                and hidden_states.shape[0] <= 128
             ):
                 hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                     hidden_states, residual
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
 
 
 def ensure_workspace_initialized(
-    max_token_num: int = 1024, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -119,12 +119,12 @@ def ensure_workspace_initialized(
     return _workspace_manager.initialized
 
 
-def flashinfer_allreduce_add_rmsnorm(
+def flashinfer_allreduce_residual_rmsnorm(
     input_tensor: torch.Tensor,
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 1024,
+    max_token_num: int = 128,
     use_oneshot: bool = True,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
@@ -174,11 +174,11 @@ class RMSNorm(CustomOp):
         if residual is not None:
             from sglang.srt.distributed import get_tensor_model_parallel_world_size
             from sglang.srt.layers.flashinfer_comm_fusion import (
-                flashinfer_allreduce_add_rmsnorm,
+                flashinfer_allreduce_residual_rmsnorm,
             )
 
             if get_tensor_model_parallel_world_size() > 1:
-                fused_result = flashinfer_allreduce_add_rmsnorm(
+                fused_result = flashinfer_allreduce_residual_rmsnorm(
                     input_tensor=x,
                     residual=residual,
                     weight=self.weight,
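
For readers who want the intent of the guard in the first hunk in one place: the one-shot (min-latency) fused allreduce + RMSNorm kernel is only safe while the batch fits in the pre-allocated FlashInfer workspace, so both the workspace default and the call-site check are capped at 128 tokens. Below is a minimal, illustrative sketch of that gating pattern; add_rmsnorm_maybe_fused and MAX_ONESHOT_TOKENS are made-up names for illustration, and the fallback is a plain single-GPU RMSNorm, not sglang's actual fused or distributed code.

    # Illustrative sketch only -- hypothetical names, not the sglang or FlashInfer API.
    import torch

    MAX_ONESHOT_TOKENS = 128  # mirrors the new default in the hunks above


    def add_rmsnorm_maybe_fused(hidden_states, residual, weight, eps=1e-6):
        """Gate the fused path on token count, falling back to plain PyTorch."""
        num_tokens = hidden_states.shape[0]
        if 0 < num_tokens <= MAX_ONESHOT_TOKENS:
            # A real implementation would call the fused allreduce + add + RMSNorm
            # kernel here (layernorm.forward_with_allreduce_fusion in the diff above).
            pass
        # Unfused reference path: residual add followed by RMSNorm
        # (the cross-GPU allreduce itself is omitted in this single-GPU sketch).
        residual = residual + hidden_states
        variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
        normed = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype)
        return normed * weight, residual


    if __name__ == "__main__":
        x = torch.randn(64, 4096)
        res = torch.randn(64, 4096)
        w = torch.ones(4096)
        out, new_res = add_rmsnorm_maybe_fused(x, res, w)
        print(out.shape, new_res.shape)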