refactor qwen moe code, use communicator to support tp+dp (#6581)
@@ -95,6 +95,7 @@ from sglang.srt.utils import (
get_int_env_var,
is_cuda,
is_hip,
is_non_idle_and_non_empty,
log_info_on_rank0,
)

@@ -206,14 +207,6 @@ class MoEGate(nn.Module):
return logits


def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
)


class DeepseekV2MoE(nn.Module):

def __init__(

@@ -32,6 +32,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,

@@ -49,7 +50,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention

@@ -114,22 +115,22 @@ class Qwen2MoeMLP(nn.Module):
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()

self.layer_id = layer_id
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}."
)

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE

self.experts = MoEImpl(
self.experts = get_moe_impl_class()(
layer_id=self.layer_id,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
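
The removed "MoEImpl = EPMoE if ... else FusedMoE" selection is replaced by the shared get_moe_impl_class() factory imported earlier in this diff. A minimal sketch of such a factory, assuming it dispatches on the same server flags used in this PR (the real implementation lives in sglang.srt.layers.moe.ep_moe.layer and may also cover the DeepEP path):

from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.managers.schedule_batch import global_server_args_dict  # import path assumed


def get_moe_impl_class():
    # Sketch: centralize the EPMoE/FusedMoE choice instead of repeating the
    # conditional in every model file.
    if global_server_args_dict["enable_ep_moe"]:
        return EPMoE
    return FusedMoE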

@@ -159,7 +160,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
shared_output = None
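
Passing the whole ForwardBatch (rather than a bare ForwardMode) lets the sparse block pick its own execution path. A sketch of the dispatch this enables, mirroring the Qwen3MoeSparseMoeBlock hunk later in this diff (names from the diff; the body is illustrative, not verbatim):

def forward(
    self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
    # Fused/EP path unless DeepEP expert parallelism is enabled.
    if not global_server_args_dict["enable_deepep_moe"]:
        return self.forward_normal(hidden_states)
    # The DeepEP path reads forward_batch.forward_mode to skip idle/empty batches.
    return self.forward_deepep(hidden_states, forward_batch)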

@@ -276,19 +279,6 @@ class Qwen2MoeAttention(nn.Module):
return output


class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()


@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode


class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,

@@ -298,6 +288,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)

@@ -322,16 +313,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()

self.info = self._compute_info(config, layer_id=layer_id)
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
# Qwen2MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
is_previous_layer_sparse = True

if self.info.is_sparse:
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)

if self.is_layer_sparse:
self.mlp = Qwen2MoeSparseMoeBlock(
layer_id=layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),

@@ -348,27 +343,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)

@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1

@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int):
# WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
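
The deleted _compute_info/_FFNInputMode bookkeeping is what LayerScatterModes.init_new now captures in a reusable form, and LayerCommunicator then owns the gather/scatter decisions driven by it. A hedged sketch of the equivalent decision, written against the flags visible in this diff (the helper name below is hypothetical, and the actual LayerScatterModes implementation in sglang.srt.layers.communicator is richer than this):

def _ffn_wants_scattered_input(is_layer_sparse: bool) -> bool:
    # Mirrors the deleted _compute_info(): DeepEP MoE layers and fully-DP dense
    # MLPs consume 1/attn_tp_size of the tokens; everything else sees all tokens.
    if global_server_args_dict["enable_deepep_moe"] and is_layer_sparse:
        return True
    if global_server_args_dict["moe_dense_tp_size"] == 1 and not is_layer_sparse:
        return True
    return False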

is_sparse = (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)

def forward(
self,

@@ -377,108 +356,11 @@ class Qwen2MoeDecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
)
else:
raise NotImplementedError

def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)

# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
# all gather and all reduce
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
elif hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)

# Fully Connected
hidden_states = self.mlp(hidden_states)

# TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)

return hidden_states, residual

def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)

if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)

# Self Attention
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,

@@ -486,47 +368,15 @@ class Qwen2MoeDecoderLayer(nn.Module):
forward_batch=forward_batch,
)

if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)

if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
hidden_states = self.mlp(hidden_states, forward_batch)

if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)

return hidden_states, residual
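
Taken together, the added lines above replace the two hand-written paths (forward_ffn_with_full_input / forward_ffn_with_scattered_input) with a single forward that delegates every TP/DP collective to the communicator. Assembled from the new lines in this hunk, the resulting flow reads roughly as follows (a sketch; exact guards may differ slightly in the final file):

def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
    forward_batch: ForwardBatch,
    residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Input layernorm plus whatever gather the layer's scatter mode requires.
    hidden_states, residual = self.layer_communicator.prepare_attn(
        hidden_states, residual, forward_batch
    )
    if hidden_states.shape[0] != 0:
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            forward_batch=forward_batch,
        )
    # Post-attention layernorm plus the reduce-scatter/all-reduce needed before the MoE MLP.
    hidden_states, residual = self.layer_communicator.prepare_mlp(
        hidden_states, residual, forward_batch
    )
    hidden_states = self.mlp(hidden_states, forward_batch)
    # Re-gather on the last layer so downstream code sees the expected layout.
    hidden_states, residual = self.layer_communicator.postprocess_layer(
        hidden_states, residual, forward_batch
    )
    return hidden_states, residual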

@@ -38,6 +38,7 @@ from sglang.srt.distributed import (
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
attn_tp_all_gather,
attn_tp_reduce_scatter,

@@ -78,7 +79,7 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
from sglang.srt.utils import DeepEPMode, add_prefix
from sglang.srt.utils import DeepEPMode, add_prefix, is_non_idle_and_non_empty

Qwen3MoeConfig = None

@@ -150,13 +151,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
)

def forward(
self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:

if not global_server_args_dict["enable_deepep_moe"]:
return self.forward_normal(hidden_states)
else:
return self.forward_deepep(hidden_states, forward_mode)
return self.forward_deepep(hidden_states, forward_batch)

def get_moe_weights(self):
return [

@@ -180,13 +181,10 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
return final_hidden_states.view(num_tokens, hidden_dim)

def forward_deepep(
self, hidden_states: torch.Tensor, forward_mode: ForwardMode
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
if (
forward_mode is not None
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
):
forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

@@ -356,19 +354,6 @@ class Qwen3MoeAttention(nn.Module):
return output


class _FFNInputMode(Enum):
# The MLP sublayer requires 1/tp_size tokens as input
SCATTERED = auto()
# The MLP sublayer requires all tokens as input
FULL = auto()


@dataclass
class _DecoderLayerInfo:
is_sparse: bool
ffn_input_mode: _FFNInputMode


class Qwen3MoeDecoderLayer(nn.Module):
def __init__(
self,

@@ -378,6 +363,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
prefix: str = "",
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)

@@ -408,15 +394,18 @@ class Qwen3MoeDecoderLayer(nn.Module):
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()

self.info = self._compute_info(config, layer_id=layer_id)
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
self.input_is_scattered = (
layer_id > 0
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
# Qwen3MoE all layers are sparse and have no nextn now
self.is_layer_sparse = True
is_previous_layer_sparse = True

if self.info.is_sparse:
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)

if self.is_layer_sparse:
self.mlp = Qwen3MoeSparseMoeBlock(
layer_id=self.layer_id,
config=config,

@@ -436,26 +425,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
config.hidden_size, eps=config.rms_norm_eps
)

@staticmethod
def _enable_moe_dense_fully_dp():
return global_server_args_dict["moe_dense_tp_size"] == 1

@staticmethod
def _compute_info(config: PretrainedConfig, layer_id: int):
# WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
is_sparse = (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
)
ffn_input_mode = (
_FFNInputMode.SCATTERED
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
else _FFNInputMode.FULL
)
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)

def forward(
self,

@@ -464,105 +438,11 @@ class Qwen3MoeDecoderLayer(nn.Module):
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
return self.forward_ffn_with_scattered_input(
positions, hidden_states, forward_batch, residual
)
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
return self.forward_ffn_with_full_input(
positions, hidden_states, forward_batch, residual
)
else:
raise NotImplementedError

def forward_ffn_with_full_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)

# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
# Gather
if get_tensor_model_parallel_world_size() > 1:
if self.local_dp_size != 1:
if self.attn_tp_rank == 0:
hidden_states += residual
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer,
hidden_states,
)
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
dp_scatter(residual, hidden_states, forward_batch)
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
# TODO extract this bugfix
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
elif hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)

# Fully Connected
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

# TODO: use reduce-scatter in MLP to avoid this scatter
# Scatter
if self.local_dp_size != 1:
# important: forward batch.gathered_buffer is used both after scatter and after gather.
# be careful about this!
hidden_states, global_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
dp_scatter(hidden_states, global_hidden_states, forward_batch)

return hidden_states, residual

def forward_ffn_with_scattered_input(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if hidden_states.shape[0] == 0:
residual = hidden_states
else:
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(hidden_states, residual)

if self.attn_tp_size != 1 and self.input_is_scattered:
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)

# Self Attention
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,

@@ -570,47 +450,15 @@ class Qwen3MoeDecoderLayer(nn.Module):
forward_batch=forward_batch,
)

if self.attn_tp_size != 1:
if self.input_is_scattered:
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
else:
if self.attn_tp_rank == 0:
hidden_states += residual
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
hidden_states = tensor_list[self.attn_tp_rank]
attn_tp_reduce_scatter(hidden_states, tensor_list)
residual = hidden_states
if hidden_states.shape[0] != 0:
hidden_states = self.post_attention_layernorm(hidden_states)
else:
if hidden_states.shape[0] != 0:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)

if not (
self._enable_moe_dense_fully_dp()
and (not self.info.is_sparse)
and hidden_states.shape[0] == 0
):
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
hidden_states = self.mlp(hidden_states, forward_batch)

if self.is_last_layer and self.attn_tp_size != 1:
hidden_states += residual
residual = None
hidden_states, local_hidden_states = (
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
hidden_states,
)
attn_tp_all_gather(
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)

return hidden_states, residual

@@ -2026,6 +2026,14 @@ class DeepEPMode(Enum):
return DeepEPMode.normal


def is_non_idle_and_non_empty(forward_mode, hidden_states):
return (
(forward_mode is not None)
and not forward_mode.is_idle()
and hidden_states.shape[0] > 0
)
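
This is the helper deleted from the DeepSeek-V2 model in the first hunk of this diff; it now lives in sglang.srt.utils and is imported by both the DeepSeek and Qwen3 MoE code. Its call site in the Qwen3 forward_deepep hunk above, repeated here only to connect the two pieces:

forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
    # Only route tokens through the gate when there is real work to do.
    router_logits, _ = self.gate(hidden_states)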


def fast_topk(values, topk, dim):
if topk == 1:
# Use max along the specified dimension to get both value and index
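
The page truncates the hunk at this point. For completeness, a plausible body for fast_topk, assuming the usual torch.max / torch.topk split (this completion is an assumption, not part of the diff):

import torch

def fast_topk(values, topk, dim):
    if topk == 1:
        # torch.max along a dim returns (values, indices), matching topk's output shape.
        return torch.max(values, dim=dim, keepdim=True)
    else:
        return torch.topk(values, topk, dim=dim)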