fuse allreduce and residual_rmsnorm (#8731)
This commit is contained in:
@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
|
||||
and _is_flashinfer_available
|
||||
and hasattr(layernorm, "forward_with_allreduce_fusion")
|
||||
and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
|
||||
and hidden_states.shape[0] <= 128
|
||||
and hidden_states.shape[0] <= 2048
|
||||
):
|
||||
hidden_states, residual = layernorm.forward_with_allreduce_fusion(
|
||||
hidden_states, residual
|
||||
|
||||
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()
|
||||
|
||||
|
||||
def ensure_workspace_initialized(
|
||||
max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
|
||||
max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
|
||||
):
|
||||
"""Ensure workspace is initialized"""
|
||||
if not is_flashinfer_available() or _flashinfer_comm is None:
|
||||
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
|
||||
residual: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
eps: float = 1e-6,
|
||||
max_token_num: int = 128,
|
||||
use_oneshot: bool = True,
|
||||
max_token_num: int = 2048,
|
||||
use_oneshot: Optional[bool] = None,
|
||||
trigger_completion_at_end: bool = False,
|
||||
fp32_acc: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
|
||||
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
|
||||
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||
sm.tag(output_parallel)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
|
||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
||||
@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module):
|
||||
)
|
||||
sm.tag(final_hidden_states)
|
||||
|
||||
final_hidden_states = final_hidden_states[
|
||||
..., :origin_hidden_states_dim
|
||||
].contiguous()
|
||||
|
||||
if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
|
||||
return final_hidden_states
|
||||
|
||||
@classmethod
|
||||
def make_expert_params_mapping(
|
||||
|
||||
@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
|
||||
self,
|
||||
x,
|
||||
forward_batch=None,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
):
|
||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||
@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(
|
||||
x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
|
||||
x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
|
||||
)
|
||||
return x
|
||||
|
||||
@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: Optional[ForwardBatch] = None,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if not self._enable_deepep_moe:
|
||||
@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
|
||||
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||
):
|
||||
return self.forward_normal_dual_stream(
|
||||
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||
hidden_states, should_allreduce_fusion, use_reduce_scatter
|
||||
)
|
||||
else:
|
||||
return self.forward_normal(
|
||||
hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||
hidden_states, should_allreduce_fusion, use_reduce_scatter
|
||||
)
|
||||
else:
|
||||
return self.forward_deepep(hidden_states, forward_batch)
|
||||
@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
def forward_normal_dual_stream(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
|
||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||
final_hidden_states = final_hidden_states_out
|
||||
sm.tag(final_hidden_states)
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_normal(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||
self.shared_experts.gate_up_proj
|
||||
):
|
||||
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
|
||||
return self.forward_cpu(hidden_states, should_allreduce_fusion)
|
||||
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
|
||||
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
|
||||
final_hidden_states = final_hidden_states_out
|
||||
sm.tag(final_hidden_states)
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
def forward_cpu(
|
||||
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
should_allreduce_fusion: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits = self.gate(hidden_states)
|
||||
@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
None, # a2_scale
|
||||
True, # is_vnni
|
||||
)
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
return final_hidden_states
|
||||
|
||||
@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
allow_reduce_scatter=True,
|
||||
)
|
||||
|
||||
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
|
||||
|
||||
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
||||
return is_nextn or (
|
||||
self.config.n_routed_experts is not None
|
||||
@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
|
||||
"""Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
|
||||
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
|
||||
|
||||
if (
|
||||
self.layer_id == self.config.num_hidden_layers - 1
|
||||
or get_tensor_model_parallel_world_size() <= 1
|
||||
):
|
||||
batch_size = (
|
||||
forward_batch.input_ids.shape[0]
|
||||
if hasattr(forward_batch, "input_ids")
|
||||
else 0
|
||||
)
|
||||
|
||||
if batch_size > 128:
|
||||
return False
|
||||
|
||||
if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
|
||||
return False
|
||||
|
||||
if not _is_sm100_supported or not _is_flashinfer_available:
|
||||
return False
|
||||
|
||||
if hasattr(forward_batch, "input_ids") and (
|
||||
forward_batch.input_ids.shape[0] == 0
|
||||
or forward_batch.input_ids.shape[0] > 128
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
return self._fuse_allreduce_lookup_table.get(batch_size, False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
can_fuse_mlp_allreduce = (
|
||||
should_allreduce_fusion = (
|
||||
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
|
||||
and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
|
||||
and not self.is_nextn
|
||||
@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
forward_batch
|
||||
)
|
||||
hidden_states = self.mlp(
|
||||
hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
|
||||
hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
|
||||
)
|
||||
|
||||
if can_fuse_mlp_allreduce:
|
||||
if should_allreduce_fusion:
|
||||
hidden_states._sglang_needs_allreduce_fusion = True
|
||||
|
||||
if not can_fuse_mlp_allreduce:
|
||||
if not should_allreduce_fusion:
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module):
|
||||
)
|
||||
return output
|
||||
|
||||
def _build_fuse_allreduce_lookup_table(self):
|
||||
static_conditions_met = (
|
||||
self.layer_id != self.config.num_hidden_layers - 1
|
||||
and get_tensor_model_parallel_world_size() > 1
|
||||
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
|
||||
and _is_sm100_supported
|
||||
and _is_flashinfer_available
|
||||
)
|
||||
|
||||
if not static_conditions_met:
|
||||
return {}
|
||||
|
||||
lookup_table = {}
|
||||
for batch_size in range(129): # 0 to 128
|
||||
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
|
||||
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
|
||||
lookup_table[batch_size] = should_fuse
|
||||
|
||||
return lookup_table
|
||||
|
||||
|
||||
class DeepseekV2Model(nn.Module):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
|
||||
def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
|
||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||
return x
|
||||
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
|
||||
x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
|
||||
return x
|
||||
|
||||
|
||||
@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
def forward_normal_dual_stream(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
|
||||
@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
if self.ep_size > 1:
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not can_fuse_mlp_allreduce
|
||||
and not should_allreduce_fusion
|
||||
and not use_reduce_scatter
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
final_hidden_states += shared_output
|
||||
if (
|
||||
self.tp_size > 1
|
||||
and not can_fuse_mlp_allreduce
|
||||
and not should_allreduce_fusion
|
||||
and not use_reduce_scatter
|
||||
):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
def forward_normal(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
can_fuse_mlp_allreduce: bool = False,
|
||||
should_allreduce_fusion: bool = False,
|
||||
use_reduce_scatter: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||
self.shared_experts.gate_up_proj
|
||||
):
|
||||
return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
|
||||
return self.forward_cpu(hidden_states, should_allreduce_fusion)
|
||||
|
||||
shared_output = self._forward_shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
# fused in biased_grouped_topk so we can skip here
|
||||
final_hidden_states *= self.routed_scaling_factor
|
||||
if self.ep_size > 1:
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
|
||||
else:
|
||||
if shared_output is not None:
|
||||
final_hidden_states += shared_output
|
||||
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states
|
||||
)
|
||||
|
||||
@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.rotary_embedding import get_rope
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
|
||||
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead,
|
||||
VocabParallelEmbedding,
|
||||
@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.utils import add_prefix, make_layers
|
||||
from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers
|
||||
|
||||
_is_flashinfer_available = is_flashinfer_available()
|
||||
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||
|
||||
|
||||
class GptOssConfig(PretrainedConfig):
|
||||
@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: Optional[ForwardBatch] = None,
|
||||
should_allreduce_fusion: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
|
||||
return self.forward_normal(hidden_states)
|
||||
return self.forward_normal(hidden_states, should_allreduce_fusion)
|
||||
else:
|
||||
raise Exception("forward_deepep branch not implemented yet")
|
||||
|
||||
@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
if name not in ["correction_bias"]
|
||||
]
|
||||
|
||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def forward_normal(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
should_allreduce_fusion: bool = False,
|
||||
) -> torch.Tensor:
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module):
|
||||
kwargs["topk_output"] = (self.top_k, router_logits)
|
||||
final_hidden_states = self.experts(**kwargs)
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.tp_size > 1 and not should_allreduce_fusion:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
ans = final_hidden_states.view(num_tokens, hidden_dim)
|
||||
@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module):
|
||||
|
||||
# GptOss all layers are sparse and have no nextn now
|
||||
self.is_layer_sparse = True
|
||||
self.is_nextn = False
|
||||
is_previous_layer_sparse = True
|
||||
|
||||
self.layer_scatter_modes = LayerScatterModes.init_new(
|
||||
@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module):
|
||||
post_attention_layernorm=self.post_attention_layernorm,
|
||||
)
|
||||
|
||||
self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
|
||||
|
||||
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
|
||||
"""Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
|
||||
|
||||
batch_size = (
|
||||
forward_batch.input_ids.shape[0]
|
||||
if hasattr(forward_batch, "input_ids")
|
||||
else 0
|
||||
)
|
||||
|
||||
if batch_size > 128:
|
||||
return False
|
||||
|
||||
return self._fuse_allreduce_lookup_table.get(batch_size, False)
|
||||
|
||||
def _build_fuse_allreduce_lookup_table(self):
|
||||
static_conditions_met = (
|
||||
self.layer_id != self.config.num_hidden_layers - 1
|
||||
and get_tensor_model_parallel_world_size() > 1
|
||||
and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
|
||||
and _is_sm100_supported
|
||||
and _is_flashinfer_available
|
||||
)
|
||||
|
||||
if not static_conditions_met:
|
||||
return {}
|
||||
|
||||
lookup_table = {}
|
||||
for batch_size in range(129): # 0 to 128
|
||||
is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
|
||||
should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
|
||||
lookup_table[batch_size] = should_fuse
|
||||
|
||||
return lookup_table
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
@@ -424,12 +471,21 @@ class GptOssDecoderLayer(nn.Module):
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
hidden_states = self.mlp(hidden_states, forward_batch)
|
||||
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
should_allreduce_fusion = (
|
||||
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
|
||||
and not self.is_nextn
|
||||
)
|
||||
|
||||
hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
|
||||
|
||||
if should_allreduce_fusion:
|
||||
hidden_states._sglang_needs_allreduce_fusion = True
|
||||
|
||||
if not should_allreduce_fusion:
|
||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||
hidden_states, residual, forward_batch
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
|
||||
@@ -1435,7 +1435,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-allreduce-fusion",
|
||||
action="store_true",
|
||||
help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
|
||||
help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--deepep-mode",
|
||||
|
||||
Reference in New Issue
Block a user