Co-authored-by: King.Zevin <zevin@mail.ustc.edu.cn> Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
@@ -269,6 +269,7 @@ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
|
||||
batch,
|
||||
dp_size=model_runner.server_args.dp_size,
|
||||
attn_tp_size=1,
|
||||
moe_dense_tp_size=model_runner.server_args.moe_dense_tp_size,
|
||||
tp_cpu_group=model_runner.tp_group.cpu_group,
|
||||
get_idle_batch=None,
|
||||
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
||||
|
||||
@@ -142,16 +142,6 @@ def get_local_attention_dp_size():
|
||||
return _LOCAL_ATTN_DP_SIZE
|
||||
|
||||
|
||||
def get_local_attention_dp_rank():
|
||||
assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_RANK
|
||||
|
||||
|
||||
def get_local_attention_dp_size():
|
||||
assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
|
||||
return _LOCAL_ATTN_DP_SIZE
|
||||
|
||||
|
||||
@contextmanager
|
||||
def disable_dp_size():
|
||||
"""Patch the tp group temporarily until this function ends.
|
||||
|
||||
@@ -16,6 +16,8 @@
|
||||
# https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/qwen2_moe.py
|
||||
"""Inference-only Qwen2MoE model compatible with HuggingFace weights."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
@@ -28,6 +30,15 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -35,7 +46,7 @@ from sglang.srt.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -82,8 +93,7 @@ class Qwen2MoeMLP(nn.Module):
|
||||
)
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(
|
||||
f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now."
|
||||
f"Unsupported activation: {hidden_act}. Only silu is supported for now."
|
||||
)
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
@@ -160,7 +170,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
|
||||
)
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
if self.tp_size > 1:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
@@ -182,20 +191,23 @@ class Qwen2MoeAttention(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
assert self.total_num_heads % attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // attn_tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
if self.total_num_kv_heads >= attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@@ -210,6 +222,8 @@ class Qwen2MoeAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=qkv_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
@@ -218,6 +232,9 @@ class Qwen2MoeAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
@@ -252,6 +269,19 @@ class Qwen2MoeAttention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class _FFNInputMode(Enum):
|
||||
# The MLP sublayer requires 1/tp_size tokens as input
|
||||
SCATTERED = auto()
|
||||
# The MLP sublayer requires all tokens as input
|
||||
FULL = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DecoderLayerInfo:
|
||||
is_sparse: bool
|
||||
ffn_input_mode: _FFNInputMode
|
||||
|
||||
|
||||
class Qwen2MoeDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -279,14 +309,21 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
# `mlp_only_layers` in the config.
|
||||
mlp_only_layers = (
|
||||
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
|
||||
self.info = self._compute_info(config, layer_id=layer_id)
|
||||
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
|
||||
self.input_is_scattered = (
|
||||
layer_id > 0
|
||||
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
)
|
||||
if (layer_id not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||
):
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
if self.info.is_sparse:
|
||||
self.mlp = Qwen2MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@@ -305,28 +342,185 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _enable_moe_dense_fully_dp():
|
||||
return global_server_args_dict["moe_dense_tp_size"] == 1
|
||||
|
||||
@staticmethod
|
||||
def _compute_info(config: PretrainedConfig, layer_id: int):
|
||||
# WARN: Qwen2MOE has no dense_layer, it is only for compatibility.
|
||||
mlp_only_layers = (
|
||||
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
||||
)
|
||||
is_sparse = (layer_id not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||
)
|
||||
ffn_input_mode = (
|
||||
_FFNInputMode.SCATTERED
|
||||
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
|
||||
or (Qwen2MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
|
||||
else _FFNInputMode.FULL
|
||||
)
|
||||
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
|
||||
return self.forward_ffn_with_scattered_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
||||
return self.forward_ffn_with_full_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_ffn_with_full_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
# all gather and all reduce
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
# TODO extract this bugfix
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
# TODO extract this bugfix
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
elif hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
# TODO: use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_ffn_with_scattered_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1:
|
||||
if self.input_is_scattered:
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
else:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
residual = hidden_states
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
if not (
|
||||
self._enable_moe_dense_fully_dp()
|
||||
and (not self.info.is_sparse)
|
||||
and hidden_states.shape[0] == 0
|
||||
):
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
|
||||
if self.is_last_layer and self.attn_tp_size != 1:
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -345,6 +539,7 @@ class Qwen2MoeModel(nn.Module):
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
||||
prefix=add_prefix("embed_tokens", prefix),
|
||||
)
|
||||
# Use the provided decoder layer type or default to Qwen2MoeDecoderLayer
|
||||
@@ -379,12 +574,12 @@ class Qwen2MoeModel(nn.Module):
|
||||
hidden_states, residual = layer(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Qwen2MoeForCausalLM(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
@@ -414,7 +609,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
|
||||
@@ -17,12 +17,15 @@
|
||||
|
||||
"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -32,6 +35,15 @@ from sglang.srt.distributed import (
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
attn_tp_all_gather,
|
||||
attn_tp_reduce_scatter,
|
||||
dp_gather_partial,
|
||||
dp_scatter,
|
||||
get_attention_tp_rank,
|
||||
get_attention_tp_size,
|
||||
get_local_attention_dp_size,
|
||||
)
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -39,7 +51,7 @@ from sglang.srt.layers.linear import (
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
@@ -128,20 +140,23 @@ class Qwen3MoeAttention(nn.Module):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
attn_tp_rank = get_attention_tp_rank()
|
||||
attn_tp_size = get_attention_tp_size()
|
||||
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % self.tp_size == 0
|
||||
self.num_heads = self.total_num_heads // self.tp_size
|
||||
assert self.total_num_heads % attn_tp_size == 0
|
||||
self.num_heads = self.total_num_heads // attn_tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= self.tp_size:
|
||||
if self.total_num_kv_heads >= attn_tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % self.tp_size == 0
|
||||
assert self.total_num_kv_heads % attn_tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
|
||||
assert attn_tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
|
||||
self.head_dim = head_dim or hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
@@ -157,6 +172,8 @@ class Qwen3MoeAttention(nn.Module):
|
||||
self.total_num_kv_heads,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
prefix=add_prefix("qkv_proj", prefix),
|
||||
)
|
||||
|
||||
@@ -165,6 +182,9 @@ class Qwen3MoeAttention(nn.Module):
|
||||
hidden_size,
|
||||
bias=attention_bias,
|
||||
quant_config=quant_config,
|
||||
tp_rank=attn_tp_rank,
|
||||
tp_size=attn_tp_size,
|
||||
reduce_results=False,
|
||||
prefix=add_prefix("o_proj", prefix),
|
||||
)
|
||||
|
||||
@@ -213,6 +233,19 @@ class Qwen3MoeAttention(nn.Module):
|
||||
return output
|
||||
|
||||
|
||||
class _FFNInputMode(Enum):
|
||||
# The MLP sublayer requires 1/tp_size tokens as input
|
||||
SCATTERED = auto()
|
||||
# The MLP sublayer requires all tokens as input
|
||||
FULL = auto()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DecoderLayerInfo:
|
||||
is_sparse: bool
|
||||
ffn_input_mode: _FFNInputMode
|
||||
|
||||
|
||||
class Qwen3MoeDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -246,14 +279,21 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
prefix=add_prefix("self_attn", prefix),
|
||||
)
|
||||
|
||||
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
|
||||
# `mlp_only_layers` in the config.
|
||||
mlp_only_layers = (
|
||||
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attn_tp_size = get_attention_tp_size()
|
||||
self.attn_tp_rank = get_attention_tp_rank()
|
||||
self.local_dp_size = get_local_attention_dp_size()
|
||||
|
||||
self.info = self._compute_info(config, layer_id=layer_id)
|
||||
previous_layer_info = self._compute_info(config, layer_id=layer_id - 1)
|
||||
self.input_is_scattered = (
|
||||
layer_id > 0
|
||||
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
||||
)
|
||||
if (layer_id not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||
):
|
||||
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
||||
|
||||
if self.info.is_sparse:
|
||||
self.mlp = Qwen3MoeSparseMoeBlock(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
@@ -272,28 +312,182 @@ class Qwen3MoeDecoderLayer(nn.Module):
|
||||
config.hidden_size, eps=config.rms_norm_eps
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _enable_moe_dense_fully_dp():
|
||||
return global_server_args_dict["moe_dense_tp_size"] == 1
|
||||
|
||||
@staticmethod
|
||||
def _compute_info(config: PretrainedConfig, layer_id: int):
|
||||
# WARN: Qwen3MOE has no dense_layer, it is only for compatibility.
|
||||
mlp_only_layers = (
|
||||
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
|
||||
)
|
||||
is_sparse = (layer_id not in mlp_only_layers) and (
|
||||
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
|
||||
)
|
||||
ffn_input_mode = (
|
||||
_FFNInputMode.SCATTERED
|
||||
if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
|
||||
or (Qwen3MoeDecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
|
||||
else _FFNInputMode.FULL
|
||||
)
|
||||
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
|
||||
return self.forward_ffn_with_scattered_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
||||
return self.forward_ffn_with_full_input(
|
||||
positions, hidden_states, forward_batch, residual
|
||||
)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
raise NotImplementedError
|
||||
|
||||
def forward_ffn_with_full_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
# Self Attention
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
# Gather
|
||||
if get_tensor_model_parallel_world_size() > 1:
|
||||
if self.local_dp_size != 1:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer,
|
||||
hidden_states,
|
||||
)
|
||||
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
||||
dp_scatter(residual, hidden_states, forward_batch)
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
# TODO extract this bugfix
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
elif hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
|
||||
# TODO: use reduce-scatter in MLP to avoid this scatter
|
||||
# Scatter
|
||||
if self.local_dp_size != 1:
|
||||
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
||||
# be careful about this!
|
||||
hidden_states, global_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
def forward_ffn_with_scattered_input(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if hidden_states.shape[0] == 0:
|
||||
residual = hidden_states
|
||||
else:
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
if self.attn_tp_size != 1 and self.input_is_scattered:
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
# Self Attention
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
forward_batch=forward_batch,
|
||||
)
|
||||
|
||||
if self.attn_tp_size != 1:
|
||||
if self.input_is_scattered:
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
else:
|
||||
if self.attn_tp_rank == 0:
|
||||
hidden_states += residual
|
||||
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
||||
hidden_states = tensor_list[self.attn_tp_rank]
|
||||
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
||||
residual = hidden_states
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
else:
|
||||
if hidden_states.shape[0] != 0:
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual
|
||||
)
|
||||
|
||||
if not (
|
||||
self._enable_moe_dense_fully_dp()
|
||||
and (not self.info.is_sparse)
|
||||
and hidden_states.shape[0] == 0
|
||||
):
|
||||
hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
|
||||
|
||||
if self.is_last_layer and self.attn_tp_size != 1:
|
||||
hidden_states += residual
|
||||
residual = None
|
||||
hidden_states, local_hidden_states = (
|
||||
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
||||
hidden_states,
|
||||
)
|
||||
attn_tp_all_gather(
|
||||
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
||||
)
|
||||
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
@@ -313,7 +507,6 @@ class Qwen3MoeModel(Qwen2MoeModel):
|
||||
|
||||
|
||||
class Qwen3MoeForCausalLM(nn.Module):
|
||||
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(
|
||||
@@ -343,7 +536,7 @@ class Qwen3MoeForCausalLM(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
) -> LogitsProcessorOutput:
|
||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
||||
return self.logits_processor(
|
||||
input_ids, hidden_states, self.lm_head, forward_batch
|
||||
|
||||
Reference in New Issue
Block a user