sglang/python/sglang/srt/models/bailing_moe.py
Yuan Luo 24dc2bee97 Fix Bailing MoE model bugs (#10362)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
Co-authored-by: 羽癫 <yudian.zy@antgroup.com>
2025-09-12 00:36:02 -07:00


# coding=utf-8
# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" SGLang BailingMoE model."""
import logging
from typing import Any, Dict, Iterable, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_pp_group,
get_tensor_model_parallel_world_size,
parallel_state,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.communicator import (
LayerCommunicator,
LayerScatterModes,
enable_moe_dense_fully_dp,
)
from sglang.srt.layers.dp_attention import (
get_attention_dp_size,
get_attention_tp_rank,
get_attention_tp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe import get_moe_a2a_backend
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers
LoraConfig = None
logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
class BailingMoEMLP(nn.Module):
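"""Dense gated MLP used for the leading dense layers and as the shared expert.

gate_up_proj is a merged column-parallel projection producing [gate; up],
SiluAndMul computes silu(gate) * up, and down_proj is a row-parallel
projection whose all-reduce can be skipped when reduce-scatter is used.
"""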
def __init__(
self,
intermediate_size: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: Optional[bool] = True,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
) -> None:
super().__init__()
self.tp_size = tp_size
self.gate_up_proj = MergedColumnParallelLinear(
config.hidden_size,
[intermediate_size] * 2,
bias=config.use_bias,
quant_config=quant_config,
prefix=add_prefix("gate_up_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
self.down_proj = RowParallelLinear(
intermediate_size,
config.hidden_size,
bias=config.use_bias,
reduce_results=reduce_results,
quant_config=quant_config,
prefix=add_prefix("down_proj", prefix),
tp_rank=tp_rank,
tp_size=tp_size,
)
if config.hidden_act != "silu":
raise ValueError("Unsupported activation. Only silu is supported for now.")
self.act_fn = SiluAndMul()
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if (self.tp_size == 1) and hidden_states.shape[0] == 0:
return hidden_states
gate_up, _ = self.gate_up_proj(hidden_states)
hidden_states = self.act_fn(gate_up)
hidden_states, _ = self.down_proj(
hidden_states, skip_all_reduce=use_reduce_scatter
)
return hidden_states
class BailingMoEGate(nn.Module):
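"""Router gate: a bias-free linear layer over the hidden states that produces
per-expert logits, optionally computed in a dedicated router dtype. An
optional float32 expert_bias is exposed and later used as the correction
bias for sigmoid scoring."""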
def __init__(
self,
config,
params_dtype: Optional[torch.dtype] = None,
prefix: str = "",
):
super().__init__()
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
self.weight = nn.Parameter(
torch.empty(
(config.num_experts, config.hidden_size),
dtype=self.params_dtype,
),
)
if getattr(config, "moe_router_enable_expert_bias", False):
self.expert_bias = nn.Parameter(
torch.empty((config.num_experts,), dtype=torch.float32),
)
else:
self.expert_bias = None
def forward(self, hidden_states):
logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to(
hidden_states.dtype
)
return logits
class BailingMoESparseMoeBlock(nn.Module):
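"""Sparse MoE block: a router gate plus TopK selection feeding the fused/EP
expert implementation, with optional shared experts. When the all-to-all
backend is DeepEP, tokens are dispatched and combined through
DeepEPDispatcher and the shared experts run without tensor parallelism."""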
def __init__(
self,
layer_id: int,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
prefix: str = "",
):
super().__init__()
self.layer_id = layer_id
self.alt_stream = alt_stream
self.tp_size = get_tensor_model_parallel_world_size()
self.top_k = config.num_experts_per_tok
self.norm_topk_prob = config.norm_topk_prob
self.hidden_size = config.hidden_size
self.num_shared_experts = config.num_shared_experts
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.score_function = getattr(config, "score_function", None)
if config.hidden_act != "silu":
raise ValueError(
f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now."
)
# Gate always runs at half / full precision for now.
router_dtype = getattr(config, "router_dtype", None)
if router_dtype is None:
self.router_dtype = None
elif router_dtype == "fp32":
self.router_dtype = torch.float32
else:
self.router_dtype = torch.bfloat16
# TODO: global_server_args_dict["ep_num_redundant_experts"] is used for EPLB, which is not supported yet.
assert global_server_args_dict["ep_num_redundant_experts"] == 0
# check group topk
self.num_expert_group = getattr(config, "n_group", 0)
self.topk_group = getattr(config, "topk_group", 0)
if self.num_expert_group > 0 or self.topk_group > 0:
assert (
self.num_expert_group > 0
and 0 < self.topk_group <= self.num_expert_group
)
self.use_grouped_topk = True
else:
self.num_expert_group = self.topk_group = None
self.use_grouped_topk = False
self.num_experts = (
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.gate = BailingMoEGate(
config=config,
params_dtype=self.router_dtype,
prefix=add_prefix("gate", prefix),
)
self.correction_bias = (
self.gate.expert_bias.data if self.gate.expert_bias is not None else None
)
if self.score_function is not None:
assert (
self.score_function == "softmax" and self.correction_bias is None
) or (
self.score_function == "sigmoid" and self.correction_bias is not None
), "score_function and correction_bias must be one of two combinations: (softmax, None) or (sigmoid, not None)"
self.topk = TopK(
top_k=self.top_k,
renormalize=self.norm_topk_prob,
use_grouped_topk=self.use_grouped_topk,
num_expert_group=self.num_expert_group,
# num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=self.topk_group,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)
self.experts = get_moe_impl_class()(
num_experts=self.num_experts,
top_k=self.top_k,
layer_id=self.layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
quant_config=quant_config,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
)
# shared expert
if config.num_shared_experts is not None:
if hasattr(config, "moe_shared_expert_intermediate_size"):
intermediate_size = config.moe_shared_expert_intermediate_size
else:
intermediate_size = config.moe_intermediate_size
intermediate_size *= config.num_shared_experts
# Disable TP for the shared experts when DeepEP MoE is enabled.
self.shared_experts = BailingMoEMLP(
intermediate_size=intermediate_size,
config=config,
quant_config=quant_config,
reduce_results=False,
prefix=add_prefix("shared_experts", prefix),
**(
dict(tp_rank=0, tp_size=1)
if get_moe_a2a_backend().is_deepep()
else {}
),
)
# dispatcher
if get_moe_a2a_backend().is_deepep():
# TODO: we will support tp < ep in the future
self.ep_size = get_tensor_model_parallel_world_size()
self.deepep_dispatcher = DeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
router_topk=self.top_k,
permute_fusion=True,
num_experts=self.num_experts,
num_local_experts=config.num_experts // self.tp_size,
hidden_size=config.hidden_size,
params_dtype=config.torch_dtype,
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
async_finish=True, # TODO
return_recv_hook=True,
)
def forward(
self,
hidden_states: torch.Tensor,
forward_batch: Optional[ForwardBatch] = None,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
if not get_moe_a2a_backend().is_deepep():
return self.forward_normal(hidden_states, use_reduce_scatter)
else:
return self.forward_deepep(hidden_states, forward_batch)
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def _forward_shared_experts(self, hidden_states: torch.Tensor):
shared_output = None
if self.num_shared_experts > 0:
shared_output = self.shared_experts(hidden_states)
return shared_output
def _forward_router_experts(self, hidden_states: torch.Tensor):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
return self.experts(hidden_states, topk_output)
def forward_normal_dual_stream(
self,
hidden_states: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
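"""Run the shared experts on the current stream and the routed experts on
alt_stream so the two GEMM paths overlap; only used for small batches
during CUDA graph capture (see forward_normal)."""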
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
shared_output = self._forward_shared_experts(hidden_states.clone())
with torch.cuda.stream(self.alt_stream):
router_output = self._forward_router_experts(hidden_states)
current_stream.wait_stream(self.alt_stream)
return router_output, shared_output
def forward_normal(
self,
hidden_states: torch.Tensor,
use_reduce_scatter: bool = False,
) -> torch.Tensor:
num_tokens, hidden_size = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_size)
DUAL_STREAM_TOKEN_THRESHOLD = 1024
if (
self.alt_stream is not None
and hidden_states.shape[0] > 0
and hidden_states.shape[0] <= DUAL_STREAM_TOKEN_THRESHOLD
and get_is_capture_mode()
):
final_hidden_states, shared_output = self.forward_normal_dual_stream(
hidden_states
)
else:
shared_output = self._forward_shared_experts(hidden_states)
final_hidden_states = self._forward_router_experts(hidden_states)
if self.num_shared_experts > 0:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1 and not use_reduce_scatter:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_size)
def forward_deepep(
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
) -> torch.Tensor:
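"""Expert-parallel forward: compute router logits and top-k locally (or
empty placeholders when this rank is idle), dispatch tokens to their
expert owners via DeepEP, run the local experts, combine the results,
apply the routed scaling factor, and add the shared-expert output."""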
shared_output = None
forward_mode = forward_batch.forward_mode
if is_non_idle_and_non_empty(forward_mode, hidden_states):
router_logits = self.gate(hidden_states)
if self.num_shared_experts > 0:
shared_output = self.shared_experts(hidden_states)
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
else:
topk_idx = torch.full(
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
)
topk_weights = torch.empty(
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
)
if self.ep_size > 1:
(
hidden_states,
topk_idx,
topk_weights,
reorder_topk_ids,
num_recv_tokens_per_expert,
seg_indptr,
masked_m,
expected_m,
) = self.deepep_dispatcher.dispatch(
hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states = self.experts(
hidden_states=hidden_states,
topk_idx=topk_idx,
topk_weights=topk_weights,
reorder_topk_ids=reorder_topk_ids,
seg_indptr=seg_indptr,
masked_m=masked_m,
expected_m=expected_m,
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
forward_batch=forward_batch,
)
if self.ep_size > 1:
final_hidden_states = self.deepep_dispatcher.combine(
final_hidden_states,
topk_idx,
topk_weights,
forward_batch=forward_batch,
)
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states
class BailingMoEAttention(nn.Module):
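"""Multi-head attention with grouped-query KV heads, (optionally partial)
rotary embeddings, and optional per-head QK RMSNorm, built on a merged
QKV column-parallel projection, a row-parallel output projection, and
RadixAttention."""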
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
reduce_results: bool = True,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
self.total_num_heads = config.num_attention_heads
self.total_kv_heads = config.num_key_value_heads
self.dp_size = get_attention_dp_size()
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
assert self.total_num_heads % attn_tp_size == 0
assert self.total_kv_heads % attn_tp_size == 0
assert self.total_num_heads >= self.total_kv_heads
self.num_heads = self.total_num_heads // attn_tp_size
self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads)
self.q_size = self.head_dim * self.num_heads
self.num_kv_heads = self.total_kv_heads // attn_tp_size
self.kv_size = max(1, self.num_kv_heads * self.head_dim)
self.scale = self.head_dim**-0.5
self.use_qk_norm = getattr(config, "use_qk_norm", False)
self.query_key_value = QKVParallelLinear(
self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_kv_heads,
bias=(config.use_bias or config.use_qkv_bias),
quant_config=quant_config,
prefix=add_prefix("query_key_value", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
if self.use_qk_norm:
self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
self.dense = RowParallelLinear(
self.total_num_heads * self.head_dim,
self.hidden_size,
bias=config.use_bias,
quant_config=quant_config,
reduce_results=reduce_results,
prefix=add_prefix("dense", prefix),
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
)
if hasattr(config, "partial_rotary_factor"):
self.rotary_dim = int(self.head_dim * config.partial_rotary_factor)
elif hasattr(config, "rotary_dim"):
self.rotary_dim = config.rotary_dim
else:
self.rotary_dim = self.head_dim
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.rotary_dim,
max_position=config.max_position_embeddings,
base=config.rope_theta,
rope_scaling=config.rope_scaling,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
)
self.alt_stream = alt_stream
def _apply_qk_norm(
self, q: torch.Tensor, k: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
# Overlap the q and k RMSNorm kernels on two CUDA streams during graph capture.
if self.alt_stream is not None and get_is_capture_mode():
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.query_layernorm(q_by_head)
with torch.cuda.stream(self.alt_stream):
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.key_layernorm(k_by_head)
current_stream.wait_stream(self.alt_stream)
else:
q_by_head = q.reshape(-1, self.head_dim)
q_by_head = self.query_layernorm(q_by_head)
k_by_head = k.reshape(-1, self.head_dim)
k_by_head = self.key_layernorm(k_by_head)
q = q_by_head.view(q.shape)
k = k_by_head.view(k.shape)
return q, k
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if hidden_states.shape[0] == 0:
return hidden_states
qkv, _ = self.query_key_value(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if self.use_qk_norm:
q, k = self._apply_qk_norm(q, k)
q, k = self.rotary_emb(positions, q, k)
context_layer = self.attn(q, k, v, forward_batch)
attn_output, _ = self.dense(context_layer)
return attn_output
class BailingMoEBlock(nn.Module):
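"""One transformer layer: pre-norm attention followed by either a dense MLP
or a sparse MoE block (for layers at or beyond first_k_dense_replace), with
DP/TP communication handled by LayerCommunicator."""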
def __init__(
self,
config: PretrainedConfig,
layer_id: int = 0,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
alt_stream: Optional[torch.cuda.Stream] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.dp_size = get_attention_dp_size()
self.attention = BailingMoEAttention(
config,
layer_id,
quant_config,
reduce_results=False,
prefix=add_prefix("attention", prefix),
alt_stream=alt_stream,
)
self.layer_id = layer_id
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.is_layer_sparse = self._is_layer_sparse(
config, layer_id=layer_id, is_nextn=False
)
is_previous_layer_sparse = self._is_layer_sparse(
config, layer_id=layer_id - 1, is_nextn=False
)
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
if self.is_layer_sparse:
self.mlp = BailingMoESparseMoeBlock(
layer_id=layer_id,
config=config,
quant_config=quant_config,
alt_stream=alt_stream,
prefix=add_prefix("mlp", prefix),
)
else:
if enable_moe_dense_fully_dp():
mlp_tp_rank, mlp_tp_size = 0, 1
else:
mlp_tp_rank, mlp_tp_size = None, None
self.mlp = BailingMoEMLP(
intermediate_size=config.intermediate_size,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
tp_rank=mlp_tp_rank,
tp_size=mlp_tp_size,
)
self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
allow_reduce_scatter=True,
)
def _is_layer_sparse(
self, config: PretrainedConfig, layer_id: int, is_nextn: bool
) -> bool:
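"""A layer is sparse (MoE) if it is a nextn draft layer or its index is at
least config.first_k_dense_replace; earlier layers stay dense."""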
return is_nextn or (
config.num_experts is not None and layer_id >= config.first_k_dense_replace
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
# For DP with padding, reduce scatter can be used instead of all-reduce.
use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
forward_batch
)
hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states=hidden_states,
residual=residual,
forward_batch=forward_batch,
)
return hidden_states, residual
class BailingMoEModel(nn.Module):
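"""The embedding (on the first pipeline rank), the stack of BailingMoEBlock
layers owned by this pipeline rank, and the final RMSNorm (on the last
rank); intermediate ranks exchange hidden_states/residual via
PPProxyTensors."""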
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
alt_stream: Optional[torch.cuda.Stream] = None,
prefix: str = "",
):
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
if self.pp_group.is_first_rank:
self.word_embeddings = VocabParallelEmbedding(
self.vocab_size,
self.embed_dim,
quant_config=quant_config,
prefix=add_prefix("word_embeddings", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
else:
self.word_embeddings = PPMissingLayer()
self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout)
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: BailingMoEBlock(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
alt_stream=alt_stream,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.word_embeddings(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
forward_batch,
residual,
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if not forward_batch.forward_mode.is_idle():
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
class BailingMoEForCausalLM(nn.Module):
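"""Causal LM wrapper: BailingMoEModel plus a (possibly tied) LM head and
LogitsProcessor, along with checkpoint weight loading and the
expert-location hook used by EPLB."""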
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
alt_stream = torch.cuda.Stream() if _is_cuda else None
self.model = BailingMoEModel(
config,
quant_config,
alt_stream=alt_stream,
prefix=add_prefix("model", ""),
)
# If tie_word_embeddings is true, reuse the word embeddings as the LM head; otherwise the LM head is a separate parameter.
if config.tie_word_embeddings:
self.lm_head = self.model.word_embeddings
else:
# TODO: ParallelLMHead currently misbehaves when DP attention is enabled.
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def get_embed_and_head(self):
"""Used by the eagle_worker."""
return self.model.word_embeddings.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
"""Used by the eagle_worker."""
del self.model.word_embeddings.weight
del self.lm_head.weight
self.model.word_embeddings.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids, hidden_states, self.lm_head, forward_batch
)
else:
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
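"""Load checkpoint weights: gate_proj/up_proj are mapped into the merged
gate_up_proj, per-expert weights go through
FusedMoE.make_expert_params_mapping, and everything else uses each
parameter's own (or the default) weight_loader. When is_nextn is set,
only the single nextn draft layer is loaded."""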
if is_nextn:
if hasattr(self.config, "num_nextn_predict_layers"):
num_nextn_layers = self.config.num_nextn_predict_layers
assert num_nextn_layers == 1, "Only 1 nextn layer is supported"
# compatible with old design
nextn_layer_id = (
0
if self.config.num_hidden_layers == 1
else self.config.num_hidden_layers
)
else:
raise ValueError("num_nextn_predict_layers is not in the config")
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
if is_nextn:
nextn_layer_prefix = f"model.layers.{nextn_layer_id}"
nextn_spec_weight_names = [
"final_layernorm",
"eh_proj",
"enorm",
"hnorm",
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = FusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if (
("v_head" in name)
or ("inv_freq" in name)
or (self.config.tie_word_embeddings and "lm_head" in name)
):
continue
if (
hasattr(self.config, "norm_head")
and self.config.norm_head
and "lm_head.weight" in name
):
loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7)
if is_nextn:
if not name.startswith(nextn_layer_prefix):
continue
# Use shared head and embed weights from target model
if "shared_head.head" in name or "embed_tokens" in name:
continue
is_decoder = True
# For nextn specific weights
for weight_name in nextn_spec_weight_names:
if weight_name in name:
name = name.replace(nextn_layer_prefix, "model")
is_decoder = False
break
# For decoder layer weights
if is_decoder:
name = name.replace(nextn_layer_prefix, "model.decoder")
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
if not is_nextn:
self.routed_experts_weights_of_layer = {
layer_id: layer.mlp.get_moe_weights()
for layer_id, layer in enumerate(self.model.layers)
if not isinstance(layer, PPMissingLayer)
and isinstance(layer.mlp, BailingMoESparseMoeBlock)
}
@classmethod
def get_model_config_for_expert_location(cls, config):
num_groups = getattr(config, "n_group", 0)
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_experts,
num_groups=None if num_groups == 0 else num_groups,
)
class BailingMoeForCausalLM(BailingMoEForCausalLM):
pass
class BailingMoeV2ForCausalLM(BailingMoEForCausalLM):
pass
EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM]
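# EntryClass is how SGLang's model loader discovers the architectures implemented
# in this file; the extra aliases presumably cover the capitalization variants
# ("BailingMoe", "BailingMoeV2") that appear in released checkpoint configs.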