[ready b200] fuse allreduce+add_rmsnorm in prepare_attention + mlp module (#7775)
This commit is contained in:
@@ -187,11 +187,24 @@ class LayerCommunicator:
|
|||||||
if hidden_states.shape[0] == 0:
|
if hidden_states.shape[0] == 0:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
else:
|
else:
|
||||||
if residual is None:
|
if (
|
||||||
residual = hidden_states
|
residual is not None
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
and hasattr(hidden_states, "_sglang_needs_allreduce_fusion")
|
||||||
|
and hidden_states._sglang_needs_allreduce_fusion
|
||||||
|
):
|
||||||
|
hidden_states, residual = (
|
||||||
|
self.input_layernorm.forward_with_allreduce_fusion(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
if residual is None:
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, residual = self.input_layernorm(
|
||||||
|
hidden_states, residual
|
||||||
|
)
|
||||||
|
|
||||||
hidden_states = self._communicate_simple_fn(
|
hidden_states = self._communicate_simple_fn(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
|
|||||||
@@ -1367,7 +1367,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
# It does not support additional parameters.
|
# It does not support additional parameters.
|
||||||
param.load_row_parallel_weight(loaded_weight)
|
param.load_row_parallel_weight(loaded_weight)
|
||||||
|
|
||||||
def forward(self, input_):
|
def forward(self, input_, can_fuse_mlp_allreduce=False):
|
||||||
if self.input_is_parallel:
|
if self.input_is_parallel:
|
||||||
input_parallel = input_
|
input_parallel = input_
|
||||||
else:
|
else:
|
||||||
@@ -1382,7 +1382,7 @@ class RowParallelLinear(LinearBase):
|
|||||||
# bias will not get added more than once in TP>1 case)
|
# bias will not get added more than once in TP>1 case)
|
||||||
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
|
||||||
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
output = tensor_model_parallel_all_reduce(output_parallel)
|
output = tensor_model_parallel_all_reduce(output_parallel)
|
||||||
else:
|
else:
|
||||||
output = output_parallel
|
output = output_parallel
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ from sglang.srt.layers.quantization.int8_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
||||||
|
from sglang.srt.layers.utils import is_sm100_supported
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
@@ -100,6 +101,7 @@ from sglang.srt.utils import (
|
|||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
is_cpu,
|
is_cpu,
|
||||||
is_cuda,
|
is_cuda,
|
||||||
|
is_flashinfer_available,
|
||||||
is_hip,
|
is_hip,
|
||||||
is_non_idle_and_non_empty,
|
is_non_idle_and_non_empty,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
@@ -132,6 +134,9 @@ if _is_hip:
|
|||||||
decode_attention_fwd_grouped_rope,
|
decode_attention_fwd_grouped_rope,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
_is_flashinfer_available = is_flashinfer_available()
|
||||||
|
_is_sm100_supported = is_cuda() and is_sm100_supported()
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -195,13 +200,13 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
)
|
)
|
||||||
self.act_fn = SiluAndMul()
|
self.act_fn = SiluAndMul()
|
||||||
|
|
||||||
def forward(self, x, forward_batch=None):
|
def forward(self, x, forward_batch=None, can_fuse_mlp_allreduce=False):
|
||||||
if (self.tp_size == 1) and x.shape[0] == 0:
|
if (self.tp_size == 1) and x.shape[0] == 0:
|
||||||
return x
|
return x
|
||||||
|
|
||||||
gate_up, _ = self.gate_up_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x = self.act_fn(gate_up)
|
x = self.act_fn(gate_up)
|
||||||
x, _ = self.down_proj(x)
|
x, _ = self.down_proj(x, can_fuse_mlp_allreduce=can_fuse_mlp_allreduce)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@@ -409,7 +414,10 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
forward_batch: Optional[ForwardBatch] = None,
|
||||||
|
can_fuse_mlp_allreduce: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if not self._enable_deepep_moe:
|
if not self._enable_deepep_moe:
|
||||||
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
DUAL_STREAM_TOKEN_THRESHOLD = 1024
|
||||||
@@ -418,13 +426,17 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
and self.num_fused_shared_experts == 0
|
and self.num_fused_shared_experts == 0
|
||||||
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
|
||||||
):
|
):
|
||||||
return self.forward_normal_dual_stream(hidden_states)
|
return self.forward_normal_dual_stream(
|
||||||
|
hidden_states, can_fuse_mlp_allreduce
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return self.forward_normal(hidden_states)
|
return self.forward_normal(hidden_states, can_fuse_mlp_allreduce)
|
||||||
else:
|
else:
|
||||||
return self.forward_deepep(hidden_states, forward_batch)
|
return self.forward_deepep(hidden_states, forward_batch)
|
||||||
|
|
||||||
def forward_normal_dual_stream(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal_dual_stream(
|
||||||
|
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||||
|
) -> torch.Tensor:
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
|
|
||||||
@@ -440,11 +452,13 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
current_stream.wait_stream(self.alt_stream)
|
current_stream.wait_stream(self.alt_stream)
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
def forward_normal(
|
||||||
|
self, hidden_states: torch.Tensor, can_fuse_mlp_allreduce: bool = False
|
||||||
|
) -> torch.Tensor:
|
||||||
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
if hasattr(self, "shared_experts") and use_intel_amx_backend(
|
||||||
self.shared_experts.gate_up_proj
|
self.shared_experts.gate_up_proj
|
||||||
):
|
):
|
||||||
@@ -461,7 +475,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
final_hidden_states *= self.routed_scaling_factor
|
final_hidden_states *= self.routed_scaling_factor
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
final_hidden_states = final_hidden_states + shared_output
|
final_hidden_states = final_hidden_states + shared_output
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
@@ -514,7 +528,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
None, # a2_scale
|
None, # a2_scale
|
||||||
True, # is_vnni
|
True, # is_vnni
|
||||||
)
|
)
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1 and not self.can_fuse_mlp_allreduce:
|
||||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||||
return final_hidden_states
|
return final_hidden_states
|
||||||
|
|
||||||
@@ -1818,6 +1832,29 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
and layer_id % self.config.moe_layer_freq == 0
|
and layer_id % self.config.moe_layer_freq == 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _should_fuse_mlp_allreduce_with_next_layer(self, forward_batch) -> bool:
|
||||||
|
"""Check if MLP allreduce can be fused with next layer's add_rmsnorm"""
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.layer_id == self.config.num_hidden_layers - 1
|
||||||
|
or get_tensor_model_parallel_world_size() <= 1
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not global_server_args_dict.get("enable_flashinfer_allreduce_fusion", False):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not _is_sm100_supported or not _is_flashinfer_available:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if hasattr(forward_batch, "input_ids") and (
|
||||||
|
forward_batch.input_ids.shape[0] == 0
|
||||||
|
or forward_batch.input_ids.shape[0] > 128
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
@@ -1842,12 +1879,27 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|||||||
hidden_states, residual, forward_batch
|
hidden_states, residual, forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.mlp(hidden_states, forward_batch)
|
can_fuse_mlp_allreduce = (
|
||||||
|
self._should_fuse_mlp_allreduce_with_next_layer(forward_batch)
|
||||||
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
and not (self.enable_dp_attention and self.speculative_algorithm.is_eagle())
|
||||||
hidden_states, residual, forward_batch
|
and not self.is_nextn
|
||||||
)
|
)
|
||||||
|
|
||||||
|
hidden_states = self.mlp(hidden_states, forward_batch, can_fuse_mlp_allreduce)
|
||||||
|
|
||||||
|
if can_fuse_mlp_allreduce:
|
||||||
|
hidden_states._sglang_needs_allreduce_fusion = True
|
||||||
|
|
||||||
|
if not can_fuse_mlp_allreduce:
|
||||||
|
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
||||||
|
hidden_states, residual, forward_batch
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.enable_dp_attention and self.speculative_algorithm.is_eagle():
|
||||||
|
# NOTE: this line resolves the degradation of MTP reception rate for non-zero DP ranks.
|
||||||
|
# See discussion here (https://github.com/sgl-project/sglang/pull/6081#discussion_r2147452251).
|
||||||
|
hidden_states = hidden_states.clone()
|
||||||
|
|
||||||
return hidden_states, residual
|
return hidden_states, residual
|
||||||
|
|
||||||
def op_comm_prepare_attn(
|
def op_comm_prepare_attn(
|
||||||
|
|||||||
Reference in New Issue
Block a user