fuse allreduce and residual_rmsnorm (#8731)
@@ -441,7 +441,7 @@ class CommunicateWithAllReduceAndLayerNormFn:
             and _is_flashinfer_available
             and hasattr(layernorm, "forward_with_allreduce_fusion")
             and global_server_args_dict["enable_flashinfer_allreduce_fusion"]
-            and hidden_states.shape[0] <= 128
+            and hidden_states.shape[0] <= 2048
         ):
             hidden_states, residual = layernorm.forward_with_allreduce_fusion(
                 hidden_states, residual
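Note: the fused call above collapses three separate steps — the tensor-parallel all-reduce of the MLP's partial output, the residual add, and the RMSNorm — into one kernel, and raises the eligible token count from 128 to 2048. A minimal sketch of the unfused semantics being replaced, assuming the usual residual-RMSNorm ordering and an initialized process group (this is not the flashinfer kernel itself):

import torch
import torch.distributed as dist

def allreduce_residual_rmsnorm_reference(hidden_states, residual, weight, eps=1e-6, group=None):
    # Sum the per-rank partial MLP outputs (the all-reduce the fusion absorbs).
    dist.all_reduce(hidden_states, group=group)
    # The summed activations plus the old residual become the next layer's residual.
    residual = residual + hidden_states
    # RMSNorm in fp32 for stability, then cast back to the activation dtype.
    x = residual.float()
    x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x.to(residual.dtype) * weight, residual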
@@ -92,7 +92,7 @@ _workspace_manager = FlashInferWorkspaceManager()


 def ensure_workspace_initialized(
-    max_token_num: int = 128, hidden_dim: int = 4096, use_fp32_lamport: bool = False
+    max_token_num: int = 2048, hidden_dim: int = 4096, use_fp32_lamport: bool = False
 ):
     """Ensure workspace is initialized"""
     if not is_flashinfer_available() or _flashinfer_comm is None:
@@ -124,8 +124,8 @@ def flashinfer_allreduce_residual_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    max_token_num: int = 128,
-    use_oneshot: bool = True,
+    max_token_num: int = 2048,
+    use_oneshot: Optional[bool] = None,
     trigger_completion_at_end: bool = False,
     fp32_acc: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
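Note: both signature changes track the new cap. max_token_num=2048 sizes the shared workspace for the larger batches now eligible for fusion, and use_oneshot: Optional[bool] = None lets the function pick between flashinfer's one-shot and two-shot all-reduce strategies instead of hard-coding one-shot. A plausible resolution inside the function, with a hypothetical threshold name:

# Hypothetical heuristic: one-shot tends to win at small token counts,
# while two-shot amortizes better as the message grows.
if use_oneshot is None:
    use_oneshot = hidden_states.shape[0] <= ONE_SHOT_TOKEN_THRESHOLD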
@@ -1294,6 +1294,7 @@ class RowParallelLinear(LinearBase):
         with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
             output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
             sm.tag(output_parallel)

         if self.reduce_results and self.tp_size > 1 and not skip_all_reduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
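Note: skip_all_reduce is the hook that makes the fusion possible. When the caller knows the next layer's fused kernel will perform the reduction, RowParallelLinear returns per-rank partial sums instead of reducing itself. The hand-off then looks roughly like this (names taken from the hunks above; a sketch, not the exact call sites):

# The down-projection leaves its output partial ...
partial, _ = self.down_proj(x, skip_all_reduce=True)
# ... and the next layer's norm reduces, adds the residual, and normalizes in one kernel.
hidden_states, residual = layernorm.forward_with_allreduce_fusion(partial, residual)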
@@ -847,10 +847,14 @@ class FusedMoE(torch.nn.Module):
             )
             sm.tag(final_hidden_states)

+        final_hidden_states = final_hidden_states[
+            ..., :origin_hidden_states_dim
+        ].contiguous()
+
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

-        return final_hidden_states[..., :origin_hidden_states_dim].contiguous()
+        return final_hidden_states

     @classmethod
     def make_expert_params_mapping(
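Note: moving the slice to origin_hidden_states_dim in front of the all-reduce drops any padded tail columns before communication rather than after, so every rank sends only the real hidden dimension. With hypothetical sizes:

# MoE output padded for alignment, e.g. 7168 -> 7200 columns.
x = torch.randn(64, 7200)
x = x[..., :7168].contiguous()  # slice first: the reduce now moves 64*7168
                                # elements per rank instead of 64*7200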
@@ -212,7 +212,7 @@ class DeepseekV2MLP(nn.Module):
         self,
         x,
         forward_batch=None,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
@@ -221,7 +221,7 @@ class DeepseekV2MLP(nn.Module):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
-            x, skip_all_reduce=can_fuse_mlp_allreduce or use_reduce_scatter
+            x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
         )
         return x

@@ -448,7 +448,7 @@ class DeepseekV2MoE(nn.Module):
         self,
         hidden_states: torch.Tensor,
         forward_batch: Optional[ForwardBatch] = None,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
@@ -459,11 +459,11 @@ class DeepseekV2MoE(nn.Module):
                 and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
             ):
                 return self.forward_normal_dual_stream(
-                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
             else:
                 return self.forward_normal(
-                    hidden_states, can_fuse_mlp_allreduce, use_reduce_scatter
+                    hidden_states, should_allreduce_fusion, use_reduce_scatter
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -471,7 +471,7 @@ class DeepseekV2MoE(nn.Module):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -500,20 +500,20 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -537,12 +537,14 @@ class DeepseekV2MoE(nn.Module):
             torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
             final_hidden_states = final_hidden_states_out
             sm.tag(final_hidden_states)
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce and not use_reduce_scatter:
+        if self.tp_size > 1 and not should_allreduce_fusion and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

     def forward_cpu(
-        self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
@@ -593,7 +595,7 @@ class DeepseekV2MoE(nn.Module):
             None,  # a2_scale
             True,  # is_vnni
         )
-        if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
         return final_hidden_states

@@ -1842,6 +1844,8 @@ class DeepseekV2DecoderLayer(nn.Module):
             allow_reduce_scatter=True,
         )

+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
     def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
         return is_nextn or (
             self.config.n_routed_experts is not None
@@ -1850,27 +1854,18 @@ class DeepseekV2DecoderLayer(nn.Module):
         )

     def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
-        """Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""

-        if (
-            self.layer_id == self.config.num_hidden_layers - 1
-            or get_tensor_model_parallel_world_size() <= 1
-        ):
-            return False
-
-        if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
-            return False
-
-        if not _is_sm100_supported or not _is_flashinfer_available:
-            return False
-
-        if hasattr(forward_batch, "input_ids") and (
-            forward_batch.input_ids.shape[0] == 0
-            or forward_batch.input_ids.shape[0] > 128
-        ):
-            return False
-
-        return True
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)

     def forward(
         self,
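Note: the rewrite turns a chain of per-forward condition checks into a single dict lookup. Every static condition (layer position, TP world size, the server flag, SM100 support, flashinfer availability) is evaluated once at init when _fuse_allreduce_lookup_table is built, so the only runtime inputs are the batch size and the > 128 fast path. (The per-layer gate keeps the 128-token cap even though the communicator path now accepts up to 2048.) Written out, the runtime decision reduces to:

def should_fuse(batch_size: int, table: dict) -> bool:
    # Fast exit for batch sizes outside the fusion range.
    if batch_size > 128:
        return False
    # All static conditions were folded into the table at init time.
    return table.get(batch_size, False)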
@@ -1896,7 +1891,7 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        can_fuse_mlp_allreduce = (
+        should_allreduce_fusion = (
             self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
             and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
             and not self.is_nextn
@@ -1907,13 +1902,13 @@ class DeepseekV2DecoderLayer(nn.Module):
             forward_batch
         )
         hidden_states = self.mlp(
-            hidden_states, forward_batch, can_fuse_mlp_allreduce, use_reduce_scatter
+            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
         )

-        if can_fuse_mlp_allreduce:
+        if should_allreduce_fusion:
             hidden_states._sglang_needs_allreduce_fusion = True

-        if not can_fuse_mlp_allreduce:
+        if not should_allreduce_fusion:
             hidden_states, residual = self.layer_communicator.postprocess_layer(
                 hidden_states, residual, forward_batch
             )
@@ -1990,6 +1985,26 @@ class DeepseekV2DecoderLayer(nn.Module):
         )
         return output

+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+

 class DeepseekV2Model(nn.Module):
     fall_back_to_pt_during_load = False
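Note: static_conditions_met already excludes the last layer, so is_last_layer is False on every loop iteration. Whenever the static gate passes, the table is exactly {0: False, 1: True, ..., 128: True}:

# Equivalent to the loop above when static_conditions_met holds:
expected = {batch_size: batch_size > 0 for batch_size in range(129)}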
@@ -154,13 +154,13 @@ class Glm4MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
+    def forward(self, x, forward_batch=None, should_allreduce_fusion=False):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x, skip_all_reduce=can_fuse_mlp_allreduce)
+        x, _ = self.down_proj(x, skip_all_reduce=should_allreduce_fusion)
         return x

@@ -529,7 +529,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal_dual_stream(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:

@@ -553,7 +553,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         if self.ep_size > 1:
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -564,7 +564,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
                 final_hidden_states += shared_output
             if (
                 self.tp_size > 1
-                and not can_fuse_mlp_allreduce
+                and not should_allreduce_fusion
                 and not use_reduce_scatter
             ):
                 final_hidden_states = tensor_model_parallel_all_reduce(
@@ -575,13 +575,13 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
     def forward_normal(
         self,
         hidden_states: torch.Tensor,
-        can_fuse_mlp_allreduce: bool = False,
+        should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
         ):
-            return self.forward_cpu(hidden_states, can_fuse_mlp_allreduce)
+            return self.forward_cpu(hidden_states, should_allreduce_fusion)

         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
@@ -596,7 +596,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
             # fused in biased_grouped_topk so we can skip here
             final_hidden_states *= self.routed_scaling_factor
         if self.ep_size > 1:
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -605,7 +605,7 @@ class Glm4MoeSparseMoeBlock(DeepseekV2MoE):
         else:
             if shared_output is not None:
                 final_hidden_states += shared_output
-            if self.tp_size > 1 and not can_fuse_mlp_allreduce:
+            if self.tp_size > 1 and not should_allreduce_fusion:
                 final_hidden_states = tensor_model_parallel_all_reduce(
                     final_hidden_states
                 )
@@ -56,7 +56,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -64,7 +64,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, make_layers
+from sglang.srt.utils import add_prefix, is_cuda, is_flashinfer_available, make_layers

+_is_flashinfer_available = is_flashinfer_available()
+_is_sm100_supported = is_cuda() and is_sm100_supported()
+

 class GptOssConfig(PretrainedConfig):
@@ -151,10 +154,13 @@ class GptOssSparseMoeBlock(nn.Module):
         )

     def forward(
-        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        should_allreduce_fusion: bool = False,
     ) -> torch.Tensor:
         if not global_server_args_dict["moe_a2a_backend"].is_deepep():
-            return self.forward_normal(hidden_states)
+            return self.forward_normal(hidden_states, should_allreduce_fusion)
         else:
             raise Exception("forward_deepep branch not implemented yet")

@@ -165,7 +171,11 @@ class GptOssSparseMoeBlock(nn.Module):
             if name not in ["correction_bias"]
         ]

-    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        should_allreduce_fusion: bool = False,
+    ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

@@ -179,7 +189,7 @@ class GptOssSparseMoeBlock(nn.Module):
         kwargs["topk_output"] = (self.top_k, router_logits)
         final_hidden_states = self.experts(**kwargs)

-        if self.tp_size > 1:
+        if self.tp_size > 1 and not should_allreduce_fusion:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         ans = final_hidden_states.view(num_tokens, hidden_dim)
@@ -370,6 +380,7 @@ class GptOssDecoderLayer(nn.Module):

         # GptOss all layers are sparse and have no nextn now
         self.is_layer_sparse = True
+        self.is_nextn = False
         is_previous_layer_sparse = True

         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -402,6 +413,42 @@ class GptOssDecoderLayer(nn.Module):
             post_attention_layernorm=self.post_attention_layernorm,
         )

+        self._fuse_allreduce_lookup_table = self._build_fuse_allreduce_lookup_table()
+
+    def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
+        """Check if MLP allreduce can be fused with next layer's residual_rmsnorm"""
+
+        batch_size = (
+            forward_batch.input_ids.shape[0]
+            if hasattr(forward_batch, "input_ids")
+            else 0
+        )
+
+        if batch_size > 128:
+            return False
+
+        return self._fuse_allreduce_lookup_table.get(batch_size, False)
+
+    def _build_fuse_allreduce_lookup_table(self):
+        static_conditions_met = (
+            self.layer_id != self.config.num_hidden_layers - 1
+            and get_tensor_model_parallel_world_size() > 1
+            and global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False)
+            and _is_sm100_supported
+            and _is_flashinfer_available
+        )
+
+        if not static_conditions_met:
+            return {}
+
+        lookup_table = {}
+        for batch_size in range(129):  # 0 to 128
+            is_last_layer = self.layer_id == self.config.num_hidden_layers - 1
+            should_fuse = batch_size > 0 and batch_size <= 128 and not is_last_layer
+            lookup_table[batch_size] = should_fuse
+
+        return lookup_table
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -424,12 +471,21 @@ class GptOssDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-        hidden_states = self.mlp(hidden_states, forward_batch)
+        should_allreduce_fusion = (
+            self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
+            and not self.is_nextn
+        )

-        hidden_states, residual = self.layer_communicator.postprocess_layer(
-            hidden_states, residual, forward_batch
-        )
+        hidden_states = self.mlp(hidden_states, forward_batch, should_allreduce_fusion)
+
+        if should_allreduce_fusion:
+            hidden_states._sglang_needs_allreduce_fusion = True
+
+        if not should_allreduce_fusion:
+            hidden_states, residual = self.layer_communicator.postprocess_layer(
+                hidden_states, residual, forward_batch
+            )

         return hidden_states, residual


@@ -1435,7 +1435,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-flashinfer-allreduce-fusion",
             action="store_true",
-            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+            help="Enable FlashInfer allreduce fusion with Residual RMSNorm.",
         )
         parser.add_argument(
             "--deepep-mode",
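Note: the fusion stays opt-in via this store_true flag. A typical launch, with an illustrative model path and TP degree:

python -m sglang.launch_server \
    --model-path deepseek-ai/DeepSeek-V2 \
    --tp-size 8 \
    --enable-flashinfer-allreduce-fusion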