diff --git a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py index 7b52f02a3..b1780f1a7 100644 --- a/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py +++ b/benchmark/kernels/fused_moe_triton/tuning_fused_moe_triton.py @@ -13,8 +13,8 @@ from ray.experimental.tqdm_ray import tqdm from transformers import AutoConfig from sglang.srt.layers.moe.fused_moe_triton import override_config -from sglang.srt.layers.moe.fused_moe_triton.fused_moe import ( - fused_moe, +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe +from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import ( get_config_dtype_str, get_config_file_name, get_default_config, @@ -441,6 +441,15 @@ def main(args: argparse.Namespace): topk = config.num_experts_per_tok intermediate_size = config.moe_intermediate_size shard_intermediate_size = 2 * intermediate_size // args.tp_size + elif config.architectures[0] in [ + "BailingMoEForCausalLM", + "BailingMoeForCausalLM", + "BailingMoeV2ForCausalLM", + ]: + E = config.num_experts + topk = config.num_experts_per_tok + intermediate_size = config.moe_intermediate_size + shard_intermediate_size = 2 * intermediate_size // args.tp_size elif config.architectures[0] in ["Glm4MoeForCausalLM"]: E = config.n_routed_experts topk = config.num_experts_per_tok diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 0dbe37aa0..d6f34b4d9 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -141,6 +141,11 @@ class ModelConfig: if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": self.hf_config.architectures[0] = "MiMoMTP" + if is_draft_model and self.hf_config.architectures[0] in [ + "BailingMoeV2ForCausalLM", + "BailingMoeForCausalLM", + ]: + self.hf_config.architectures[0] = "BailingMoeForCausalLMNextN" if ( is_draft_model and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM" diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index df2b77e08..0765b673a 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -893,6 +893,35 @@ class QKVParallelLinear(ColumnParallelLinear): ) self.weight_loader_v2(param, loaded_weight_shard, shard_id) + def _load_qkv_block_scale( + self, param: BasevLLMParameter, loaded_weight: torch.Tensor + ): + block_n, _ = self.quant_method.quant_config.weight_block_size + q_size = self.total_num_heads * self.head_size // block_n + k_size = self.total_num_kv_heads * self.head_size // block_n + v_size = self.total_num_kv_heads * self.head_size // block_n + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, q_size), + ("k", q_size, k_size), + ("v", q_size + k_size, v_size), + ] + for shard_id, shard_offset, shard_size in shard_offsets: + loaded_weight_shard = loaded_weight.narrow( + param.output_dim, shard_offset, shard_size + ) + rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n + rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n + param.load_qkv_weight( + loaded_weight=loaded_weight_shard, + num_heads=self.num_kv_head_replicas, + shard_id=shard_id, + shard_offset=rank_shard_offset, + shard_size=rank_shard_size, + tp_rank=self.tp_rank, + use_presharded_weights=self.use_presharded_weights, + ) + def weight_loader_v2( self, param: BasevLLMParameter, @@ -906,6 +935,9 @@ class 
QKVParallelLinear(ColumnParallelLinear): elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_qkv_weight(loaded_weight=loaded_weight) return + elif isinstance(param, BlockQuantScaleParameter): + self._load_qkv_block_scale(param, loaded_weight) + return # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json new file mode 100644 index 000000000..786f36789 --- /dev/null +++ b/python/sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3 + } +} diff --git a/python/sglang/srt/models/bailing_moe.py b/python/sglang/srt/models/bailing_moe.py index 73e5a9a16..b16d22320 100644 --- a/python/sglang/srt/models/bailing_moe.py +++ b/python/sglang/srt/models/bailing_moe.py @@ -1,19 +1,51 @@ 
-# Copyright 2023-2024 SGLang Team -# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/bailing_moe.py - -from collections.abc import Iterable -from typing import Optional, Tuple +# coding=utf-8 +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" SGLang BailingMoE model.""" +import logging +from typing import Any, Dict, Iterable, Optional, Tuple, Union import torch import torch.nn.functional as F from torch import nn -from transformers.configuration_utils import PretrainedConfig +from transformers import PretrainedConfig from sglang.srt.distributed import ( + get_pp_group, get_tensor_model_parallel_world_size, + parallel_state, tensor_model_parallel_all_reduce, ) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.communicator import ( + LayerCommunicator, + LayerScatterModes, + enable_moe_dense_fully_dp, +) +from sglang.srt.layers.dp_attention import ( + get_attention_dp_size, + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( MergedColumnParallelLinear, @@ -22,63 +54,457 @@ from sglang.srt.layers.linear import ( RowParallelLinear, ) from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.moe.fused_moe_triton import FusedMoE +from sglang.srt.layers.moe import get_moe_a2a_backend +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher from sglang.srt.layers.moe.topk import TopK +from sglang.srt.layers.moe.utils import DeepEPMode from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import default_weight_loader -from sglang.srt.utils import 
add_prefix, make_layers +from sglang.srt.utils import add_prefix, is_cuda, is_non_idle_and_non_empty, make_layers + +LoraConfig = None +logger = logging.getLogger(__name__) +_is_cuda = is_cuda() -class BailingAttention(nn.Module): +class BailingMoEMLP(nn.Module): + def __init__( + self, + intermediate_size: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: Optional[bool] = True, + prefix: str = "", + tp_rank: Optional[int] = None, + tp_size: Optional[int] = None, + ) -> None: + super().__init__() + self.tp_size = tp_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, + [intermediate_size] * 2, + bias=config.use_bias, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, + ) + self.down_proj = RowParallelLinear( + intermediate_size, + config.hidden_size, + bias=config.use_bias, + reduce_results=reduce_results, + quant_config=quant_config, + prefix=add_prefix("down_proj", prefix), + tp_rank=tp_rank, + tp_size=tp_size, + ) + + if config.hidden_act != "silu": + raise ValueError("Unsupported activation. Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + use_reduce_scatter: bool = False, + ) -> torch.Tensor: + if (self.tp_size == 1) and hidden_states.shape[0] == 0: + return hidden_states + + gate_up, _ = self.gate_up_proj(hidden_states) + hidden_states = self.act_fn(gate_up) + hidden_states, _ = self.down_proj(hidden_states) + return hidden_states + + +class BailingMoEGate(nn.Module): + def __init__( + self, + config, + params_dtype: Optional[torch.dtype] = None, + prefix: str = "", + ): + super().__init__() + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + self.weight = nn.Parameter( + torch.empty( + (config.num_experts, config.hidden_size), + dtype=self.params_dtype, + ), + ) + if getattr(config, "moe_router_enable_expert_bias", False): + self.expert_bias = nn.Parameter( + torch.empty((config.num_experts,), dtype=torch.float32), + ) + else: + self.expert_bias = None + + def forward(self, hidden_states): + logits = F.linear(hidden_states.to(self.weight.dtype), self.weight, None).to( + hidden_states.dtype + ) + return logits + + +class BailingMoESparseMoeBlock(nn.Module): + def __init__( + self, + layer_id: int, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + alt_stream: Optional[torch.cuda.Stream] = None, + prefix: str = "", + ): + super().__init__() + self.layer_id = layer_id + self.alt_stream = alt_stream + self.tp_size = get_tensor_model_parallel_world_size() + self.top_k = config.num_experts_per_tok + self.norm_topk_prob = config.norm_topk_prob + self.hidden_size = config.hidden_size + self.num_shared_experts = config.num_shared_experts + self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) + self.score_function = getattr(config, "score_function", None) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + # Gate always runs at half / full precision for now. 
+ router_dtype = getattr(config, "router_dtype", None) + if router_dtype is None: + self.router_dtype = None + elif router_dtype == "fp32": + self.router_dtype = torch.float32 + else: + self.router_dtype = torch.bfloat16 + + # TODO global_server_args_dict["ep_num_redundant_experts"] is used for eplb, not supported now + assert global_server_args_dict["ep_num_redundant_experts"] == 0 + # check group topk + self.num_expert_group = getattr(config, "n_group", 0) + self.topk_group = getattr(config, "topk_group", 0) + if self.num_expert_group > 0 or self.topk_group > 0: + assert ( + self.num_expert_group > 0 + and 0 < self.topk_group <= self.num_expert_group + ) + self.use_grouped_topk = True + else: + self.num_expert_group = self.topk_group = None + self.use_grouped_topk = False + + self.num_experts = ( + config.num_experts + global_server_args_dict["ep_num_redundant_experts"] + ) + + self.gate = BailingMoEGate( + config=config, + params_dtype=self.router_dtype, + prefix=add_prefix("gate", prefix), + ) + self.correction_bias = ( + self.gate.expert_bias.data if self.gate.expert_bias is not None else None + ) + + if self.score_function is not None: + assert ( + self.score_function == "softmax" and self.correction_bias is None + ) or ( + self.score_function == "sigmoid" and self.correction_bias is not None + ), "score_function and correction_bias should be in 2 combination (softmax, None) or (sigmoid, not None)" + + self.topk = TopK( + top_k=self.top_k, + renormalize=self.norm_topk_prob, + use_grouped_topk=self.use_grouped_topk, + num_expert_group=self.num_expert_group, + # num_fused_shared_experts=self.num_fused_shared_experts, + topk_group=self.topk_group, + correction_bias=self.correction_bias, + routed_scaling_factor=self.routed_scaling_factor, + ) + + self.experts = get_moe_impl_class()( + num_experts=self.num_experts, + top_k=self.top_k, + layer_id=self.layer_id, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + routed_scaling_factor=self.routed_scaling_factor, + prefix=add_prefix("experts", prefix), + ) + # shared expert + if config.num_shared_experts is not None: + if hasattr(config, "moe_shared_expert_intermediate_size"): + intermediate_size = config.moe_shared_expert_intermediate_size + else: + intermediate_size = config.moe_intermediate_size + intermediate_size *= config.num_shared_experts + # disable tp for shared experts when enable deepep moe + self.shared_experts = BailingMoEMLP( + intermediate_size=intermediate_size, + config=config, + quant_config=quant_config, + reduce_results=False, + prefix=add_prefix("shared_experts", prefix), + **( + dict(tp_rank=0, tp_size=1) + if get_moe_a2a_backend().is_deepep() + else {} + ), + ) + # dispatcher + if get_moe_a2a_backend().is_deepep(): + # TODO: we will support tp < ep in the future + self.ep_size = get_tensor_model_parallel_world_size() + + self.deepep_dispatcher = DeepEPDispatcher( + group=parallel_state.get_tp_group().device_group, + router_topk=self.top_k, + permute_fusion=True, + num_experts=self.num_experts, + num_local_experts=config.num_experts // self.tp_size, + hidden_size=config.hidden_size, + params_dtype=config.torch_dtype, + deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]], + async_finish=True, # TODO + return_recv_hook=True, + ) + + def forward( + self, + hidden_states: torch.Tensor, + forward_batch: Optional[ForwardBatch] = None, + use_reduce_scatter: bool = False, + ) -> torch.Tensor: + if not get_moe_a2a_backend().is_deepep(): + return 
self.forward_normal(hidden_states, use_reduce_scatter) + else: + return self.forward_deepep(hidden_states, forward_batch) + + def get_moe_weights(self): + return [ + x.data + for name, x in self.experts.named_parameters() + if name not in ["correction_bias"] + ] + + def _forward_shared_experts(self, hidden_states: torch.Tensor): + shared_output = None + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + return shared_output + + def _forward_router_experts(self, hidden_states: torch.Tensor): + # router_logits: (num_tokens, n_experts) + router_logits = self.gate(hidden_states) + topk_output = self.topk(hidden_states, router_logits) + return self.experts(hidden_states, topk_output) + + def forward_normal_dual_stream( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + shared_output = self._forward_shared_experts(hidden_states) + + with torch.cuda.stream(self.alt_stream): + router_output = self._forward_router_experts(hidden_states) + current_stream.wait_stream(self.alt_stream) + + return router_output, shared_output + + def forward_normal( + self, + hidden_states: torch.Tensor, + use_reduce_scatter: bool = False, + ) -> torch.Tensor: + num_tokens, hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_size) + + DUAL_STREAM_TOKEN_THRESHOLD = 1024 + if ( + self.alt_stream is not None + and num_tokens > 0 + and num_tokens <= DUAL_STREAM_TOKEN_THRESHOLD + ): + final_hidden_states, shared_output = self.forward_normal_dual_stream( + hidden_states + ) + else: + shared_output = self._forward_shared_experts(hidden_states) + final_hidden_states = self._forward_router_experts(hidden_states) + + if self.num_shared_experts > 0: + final_hidden_states = final_hidden_states + shared_output + + if self.tp_size > 1 and not use_reduce_scatter: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + return final_hidden_states.view(num_tokens, hidden_size) + + def forward_deepep( + self, hidden_states: torch.Tensor, forward_batch: ForwardBatch + ) -> torch.Tensor: + shared_output = None + forward_mode = forward_batch.forward_mode + if is_non_idle_and_non_empty(forward_mode, hidden_states): + router_logits = self.gate(hidden_states) + if self.num_shared_experts > 0: + shared_output = self.shared_experts(hidden_states) + + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, + num_token_non_padded=forward_batch.num_token_non_padded, + expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new( + layer_id=self.layer_id, + ), + ) + else: + topk_idx = torch.full( + (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device + ) + topk_weights = torch.empty( + (0, self.top_k), dtype=torch.float32, device=hidden_states.device + ) + + if self.ep_size > 1: + ( + hidden_states, + topk_idx, + topk_weights, + reorder_topk_ids, + num_recv_tokens_per_expert, + seg_indptr, + masked_m, + expected_m, + ) = self.deepep_dispatcher.dispatch( + hidden_states, + topk_idx, + topk_weights, + forward_batch=forward_batch, + ) + + final_hidden_states = self.experts( + hidden_states=hidden_states, + topk_idx=topk_idx, + topk_weights=topk_weights, + reorder_topk_ids=reorder_topk_ids, + seg_indptr=seg_indptr, + masked_m=masked_m, + expected_m=expected_m, + num_recv_tokens_per_expert=num_recv_tokens_per_expert, + forward_batch=forward_batch, + ) + if self.ep_size > 1: + final_hidden_states = self.deepep_dispatcher.combine( + 
final_hidden_states, + topk_idx, + topk_weights, + forward_batch=forward_batch, + ) + + final_hidden_states *= self.routed_scaling_factor + + if shared_output is not None: + final_hidden_states = final_hidden_states + shared_output + return final_hidden_states + + +class BailingMoEAttention(nn.Module): def __init__( self, config: PretrainedConfig, layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = True, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() self.hidden_size = config.hidden_size - tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = config.num_attention_heads - self.total_num_kv_heads = config.num_key_value_heads + self.total_kv_heads = config.num_key_value_heads + self.dp_size = get_attention_dp_size() + attn_tp_rank = get_attention_tp_rank() + attn_tp_size = get_attention_tp_size() - assert self.total_num_heads % tp_size == 0 - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_heads % attn_tp_size == 0 + assert self.total_kv_heads % attn_tp_size == 0 + assert self.total_num_heads >= self.total_kv_heads - self.num_heads = self.total_num_heads // tp_size + self.num_heads = self.total_num_heads // attn_tp_size self.head_dim = config.head_dim or (self.hidden_size // self.total_num_heads) - self.q_size = self.num_heads * self.head_dim + self.q_size = self.head_dim * self.num_heads + + self.num_kv_heads = self.total_kv_heads // attn_tp_size + self.kv_size = max(1, self.num_kv_heads * self.head_dim) - self.num_kv_heads = self.total_num_kv_heads // tp_size - self.kv_size = self.num_kv_heads * self.head_dim self.scale = self.head_dim**-0.5 + self.use_qk_norm = getattr(config, "use_qk_norm", False) + self.query_key_value = QKVParallelLinear( self.hidden_size, self.head_dim, self.total_num_heads, - self.total_num_kv_heads, + self.total_kv_heads, bias=(config.use_bias or config.use_qkv_bias), quant_config=quant_config, prefix=add_prefix("query_key_value", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, ) + if self.use_qk_norm: + self.query_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.key_layernorm = RMSNorm(self.head_dim, eps=config.rms_norm_eps) + self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=config.use_bias, quant_config=quant_config, + reduce_results=reduce_results, prefix=add_prefix("dense", prefix), + tp_rank=attn_tp_rank, + tp_size=attn_tp_size, + ) + + if hasattr(config, "partial_rotary_factor"): + self.rotary_dim = int(self.head_dim * config.partial_rotary_factor) + elif hasattr(config, "rotary_dim"): + self.rotary_dim = config.rotary_dim + else: + self.rotary_dim = self.head_dim + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_dim, + max_position=config.max_position_embeddings, + base=config.rope_theta, + rope_scaling=config.rope_scaling, ) self.attn = RadixAttention( @@ -87,291 +513,369 @@ class BailingAttention(nn.Module): self.scale, num_kv_heads=self.num_kv_heads, layer_id=layer_id, - quant_config=quant_config, prefix=add_prefix("attn", prefix), ) - self.rotary_emb = get_rope( - self.head_dim, - rotary_dim=self.head_dim, - max_position=config.max_position_embeddings, - base=config.rope_theta, - is_neox_style=True, - rope_scaling=config.rope_scaling, - ) + self.alt_stream = alt_stream + + def _apply_qk_norm( + self, q: torch.Tensor, k: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + # overlap qk norm + if self.alt_stream is not None and
get_is_capture_mode(): + current_stream = torch.cuda.current_stream() + self.alt_stream.wait_stream(current_stream) + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.query_layernorm(q_by_head) + with torch.cuda.stream(self.alt_stream): + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.key_layernorm(k_by_head) + current_stream.wait_stream(self.alt_stream) + else: + q_by_head = q.reshape(-1, self.head_dim) + q_by_head = self.query_layernorm(q_by_head) + k_by_head = k.reshape(-1, self.head_dim) + k_by_head = self.key_layernorm(k_by_head) + q = q_by_head.view(q.shape) + k = k_by_head.view(k.shape) + return q, k def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, - position_ids: torch.Tensor, forward_batch: ForwardBatch, ) -> torch.Tensor: + if hidden_states.shape[0] == 0: + return hidden_states qkv, _ = self.query_key_value(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - - q, k = self.rotary_emb(position_ids, q, k) + if self.use_qk_norm: + q, k = self._apply_qk_norm(q, k) + q, k = self.rotary_emb(positions, q, k) context_layer = self.attn(q, k, v, forward_batch) attn_output, _ = self.dense(context_layer) return attn_output -class BailingMLP(nn.Module): - def __init__( - self, - intermediate_size: int, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - reduce_results: Optional[bool] = True, - prefix: str = "", - ) -> None: - super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - config.hidden_size, - [intermediate_size] * 2, - bias=config.use_bias, - quant_config=quant_config, - prefix=add_prefix("gate_up_proj", prefix), - ) - self.down_proj = RowParallelLinear( - intermediate_size, - config.hidden_size, - bias=config.use_bias, - quant_config=quant_config, - reduce_results=reduce_results, - prefix=add_prefix("down_proj", prefix), - ) - self.act_fn = SiluAndMul() - - def forward(self, x): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) - x, _ = self.down_proj(x) - return x - - -class BailingMoE(nn.Module): - +class BailingMoEBlock(nn.Module): def __init__( self, config: PretrainedConfig, - layer_id: int, + layer_id: int = 0, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, ): super().__init__() - self.tp_size = get_tensor_model_parallel_world_size() - self.num_experts = config.num_experts - self.top_k = config.num_experts_per_tok - self.hidden_size = config.hidden_size - self.num_shared_experts = config.num_shared_experts - self.norm_expert_prob = config.norm_topk_prob - self.moe_intermediate_size = config.moe_intermediate_size + hidden_size = config.hidden_size - self.gate = ReplicatedLinear( - self.hidden_size, self.num_experts, bias=False, quant_config=None - ) - - self.topk = TopK(top_k=self.top_k, renormalize=self.norm_expert_prob) - - self.experts = FusedMoE( - num_experts=self.num_experts, - top_k=self.top_k, - layer_id=layer_id, - hidden_size=self.hidden_size, - intermediate_size=self.moe_intermediate_size, + self.input_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) + self.dp_size = get_attention_dp_size() + self.attention = BailingMoEAttention( + config, + layer_id, + quant_config, reduce_results=False, - quant_config=quant_config, - prefix=add_prefix("experts", prefix), + prefix=add_prefix("attention", prefix), + alt_stream=alt_stream, + ) + self.layer_id = layer_id + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + + 
self.is_layer_sparse = self._is_layer_sparse( + config, layer_id=layer_id, is_nextn=False + ) + is_previous_layer_sparse = self._is_layer_sparse( + config, layer_id=layer_id - 1, is_nextn=False ) - if self.num_shared_experts > 0: - shared_intermediate_size = ( - self.moe_intermediate_size * self.num_shared_experts - ) - self.shared_experts = BailingMLP( - intermediate_size=shared_intermediate_size, + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=self.is_layer_sparse, + is_previous_layer_sparse=is_previous_layer_sparse, + ) + + self.is_last_layer = self.layer_id == config.num_hidden_layers - 1 + + if self.is_layer_sparse: + self.mlp = BailingMoESparseMoeBlock( + layer_id=layer_id, config=config, quant_config=quant_config, - reduce_results=False, - prefix=add_prefix("shared_experts", prefix), + alt_stream=alt_stream, + prefix=add_prefix("mlp", prefix), ) else: - self.shared_experts = None + if enable_moe_dense_fully_dp(): + mlp_tp_rank, mlp_tp_size = 0, 1 + else: + mlp_tp_rank, mlp_tp_size = None, None + self.mlp = BailingMoEMLP( + intermediate_size=config.intermediate_size, + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + tp_rank=mlp_tp_rank, + tp_size=mlp_tp_size, + ) - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - orig_shape = hidden_states.shape - hidden_states_flat = hidden_states.view(-1, self.hidden_size) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=config.rms_norm_eps) - shared_output = None - if self.shared_experts is not None: - shared_output = self.shared_experts(hidden_states_flat) - - router_logits, _ = self.gate(hidden_states_flat) - topk_output = self.topk(hidden_states_flat, router_logits) - final_hidden_states = self.experts(hidden_states_flat, topk_output) - - if shared_output is not None: - final_hidden_states = final_hidden_states + shared_output - - if self.tp_size > 1: - final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) - - return final_hidden_states.view(orig_shape) - - -class BailingMoeBlock(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - layer_id: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ): - super().__init__() - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.attention = BailingAttention( - config, layer_id, quant_config, prefix=add_prefix("attention", prefix) + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + allow_reduce_scatter=True, ) - self.post_attention_layernorm = RMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.mlp = BailingMoE( - config=config, - layer_id=layer_id, - quant_config=quant_config, - prefix=add_prefix("mlp", prefix), + + def _is_layer_sparse( + self, config: PretrainedConfig, layer_id: int, is_nextn: bool + ) -> bool: + return is_nextn or ( + config.num_experts is not None and layer_id >= config.first_k_dense_replace ) def forward( self, + positions: torch.Tensor, hidden_states: torch.Tensor, - position_ids: torch.Tensor, - residual: Optional[torch.Tensor], forward_batch: ForwardBatch, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # Pre-normalization and residual connection for the attention block - if residual is None: - residual = hidden_states - normed_hidden_states = self.input_layernorm(hidden_states) - else: - 
normed_hidden_states, residual = self.input_layernorm( - hidden_states, residual - ) - - attn_output = self.attention( - hidden_states=normed_hidden_states, - position_ids=position_ids, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states=hidden_states, + residual=residual, forward_batch=forward_batch, ) - # Pre-normalization and residual connection for the MLP block - normed_hidden_states, residual = self.post_attention_layernorm( - attn_output, residual + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, ) - mlp_output = self.mlp(normed_hidden_states) - return mlp_output, residual + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + # For DP with padding, reduce scatter can be used instead of all-reduce. + use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter( + forward_batch + ) + + hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter) + + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + return hidden_states, residual -class BailingMoeModel(nn.Module): +class BailingMoEModel(nn.Module): def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + alt_stream: Optional[torch.cuda.Stream] = None, prefix: str = "", ): super().__init__() + self.pp_group = get_pp_group() self.config = config - self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.embed_dim = config.hidden_size + if self.pp_group.is_first_rank: + self.word_embeddings = VocabParallelEmbedding( + self.vocab_size, + self.embed_dim, + quant_config=quant_config, + prefix=add_prefix("word_embeddings", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + else: + self.word_embeddings = PPMissingLayer() - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - prefix=add_prefix("embed_tokens", prefix), - ) self.embedding_dropout = torch.nn.Dropout(config.embedding_dropout) - self.layers = make_layers( + self.layers, self.start_layer, self.end_layer = make_layers( config.num_hidden_layers, - lambda idx, prefix: BailingMoeBlock( - config=config, + lambda idx, prefix: BailingMoEBlock( layer_id=idx, + config=config, quant_config=quant_config, prefix=prefix, + alt_stream=alt_stream, ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, prefix=add_prefix("layers", prefix), ) - - self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) - - def forward( - self, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - forward_batch: ForwardBatch, - input_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if input_embeds is None: - hidden_states = self.embed_tokens(input_ids) + if self.pp_group.is_last_rank: + self.norm = RMSNorm(self.embed_dim, eps=config.rms_norm_eps) else: - hidden_states = input_embeds - - residual = None - for layer in self.layers: - hidden_states, residual = layer( - hidden_states, - position_ids, - residual, - forward_batch, - ) - - hidden_states, _ = self.norm(hidden_states, residual) - return hidden_states - - -class BailingMoeForCausalLM(nn.Module): - - def __init__( - self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - ) 
-> None: - super().__init__() - self.config = config - self.model = BailingMoeModel(config=config, quant_config=quant_config) - self.lm_head = ParallelLMHead( - num_embeddings=config.vocab_size, - embedding_dim=config.hidden_size, - quant_config=quant_config, - ) - if config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - - self.logits_processor = LogitsProcessor(config) + self.norm = PPMissingLayer(return_tuple=True) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, forward_batch: ForwardBatch, - inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds) - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch + input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = input_embeds + residual = None + else: + assert pp_proxy_tensors is not None + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + forward_batch, + residual, + ) + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + if not forward_batch.forward_mode.is_idle(): + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BailingMoEForCausalLM(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.pp_group = get_pp_group() + self.config = config + self.quant_config = quant_config + alt_stream = torch.cuda.Stream() if _is_cuda else None + + self.model = BailingMoEModel( + config, + quant_config, + alt_stream=alt_stream, + prefix=add_prefix("model", ""), + ) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # When tie_word_embeddings is True, reuse the input word embeddings as the LM head; otherwise the LM head has its own weight. + if config.tie_word_embeddings: + self.lm_head = self.model.word_embeddings + else: + # TODO: ParallelLMHead does not work correctly when DP attention is enabled + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer + + def get_embed_and_head(self): + """Used by the eagle_worker.""" + return self.model.word_embeddings.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + """Used by the eagle_worker.""" + del self.model.word_embeddings.weight + del self.lm_head.weight + self.model.word_embeddings.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, +
pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + if self.pp_group.is_last_rank: + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + else: + return hidden_states + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False): + if is_nextn: + if hasattr(self.config, "num_nextn_predict_layers"): + num_nextn_layers = self.config.num_nextn_predict_layers + assert num_nextn_layers == 1, "Only 1 nextn layer is supported" + # compatible with old design + nextn_layer_id = ( + 0 + if self.config.num_hidden_layers == 1 + else self.config.num_hidden_layers + ) + else: + raise ValueError("num_nextn_predict_layers is not in the config") stacked_params_mapping = [ + # (param_name, shard_name, shard_id) ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + if is_nextn: + nextn_layer_prefix = f"model.layers.{nextn_layer_id}" + nextn_spec_weight_names = [ + "final_layernorm", + "eh_proj", + "enorm", + "hnorm", + ] + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", @@ -381,39 +885,87 @@ class BailingMoeForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + if ( + ("v_head" in name) + or ("inv_freq" in name) + or (self.config.tie_word_embeddings and "lm_head" in name) + ): + continue if ( hasattr(self.config, "norm_head") and self.config.norm_head and "lm_head.weight" in name ): + import torch.nn.functional as F + loaded_weight = F.normalize(loaded_weight, dim=0, p=2, eps=1e-7) - if "model.word_embeddings.weight" == name: - name = "model.embed_tokens.weight" + if is_nextn: + if not name.startswith(nextn_layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") for param_name, weight_name, shard_id in stacked_params_mapping: - if weight_name in name and "mlp.experts" not in name: - full_param_name = name.replace(weight_name, param_name) - param = params_dict[full_param_name] - param.weight_loader(param, loaded_weight, shard_id) - break + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break else: - for p_name, w_name, e_id, s_id in expert_params_mapping: - if w_name in name and "mlp.experts" in name: - full_param_name = name.replace(w_name, p_name) - param = params_dict[full_param_name] - param.weight_loader( - param, - loaded_weight, - full_param_name, - shard_id=s_id, - expert_id=e_id, - ) - break + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break else: + # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if name not in params_dict: + continue param = params_dict[name] weight_loader = getattr( @@ -421,5 +973,30 @@ class BailingMoeForCausalLM(nn.Module): ) weight_loader(param, loaded_weight) + if not is_nextn: + self.routed_experts_weights_of_layer = { + layer_id: layer.mlp.get_moe_weights() + for layer_id, layer in enumerate(self.model.layers) + if not isinstance(layer, PPMissingLayer) + and isinstance(layer.mlp, BailingMoESparseMoeBlock) + } -EntryClass = BailingMoeForCausalLM + @classmethod + def get_model_config_for_expert_location(cls, config): + num_groups = getattr(config, "n_group", 0) + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.num_experts, + num_groups=None if num_groups == 0 else num_groups, + ) + + +class BailingMoeForCausalLM(BailingMoEForCausalLM): + pass + + +class BailingMoeV2ForCausalLM(BailingMoEForCausalLM): + pass + + +EntryClass = [BailingMoEForCausalLM, BailingMoeForCausalLM, BailingMoeV2ForCausalLM] diff --git a/python/sglang/srt/models/bailing_moe_nextn.py b/python/sglang/srt/models/bailing_moe_nextn.py new file mode 100644 index 000000000..49198001c --- /dev/null +++ b/python/sglang/srt/models/bailing_moe_nextn.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" SGLang BailingMoENextN model.""" +import logging +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.dp_attention import is_dp_attention_enabled +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.topk import select_experts +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM +from sglang.srt.utils import add_prefix + +LoraConfig = None +logger = logging.getLogger(__name__) + + +class BailingMoEModelNextN(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + if quant_config is not None and quant_config.get_name() == "modelopt_fp4": + logger.warning( + "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model." + ) + quant_config = None + + self.vocab_size = config.vocab_size + + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("word_embeddings", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False) + + self.decoder = BailingMoEBlock( + config, + 0, + quant_config=quant_config, + # is_nextn=True, + prefix=add_prefix("decoder", prefix), + ) + + self.shared_head = nn.Module() + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + + if input_embeds is None: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + hidden_states = self.eh_proj( + torch.cat( + ( + self.enorm(hidden_states), + self.hnorm(forward_batch.spec_info.hidden_states), + ), + dim=-1, + ) + ) + + residual = None + hidden_states, residual = self.decoder( + positions, hidden_states, forward_batch, residual + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is not None: + hidden_states, _ = self.final_layernorm(hidden_states, residual) + else: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class BailingMoeForCausalLMNextN(BailingMoEForCausalLM): + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + if hasattr(self, "determine_num_fused_shared_experts"): + # Asystem has determine_num_fused_shared_experts but theta does not. 
+ self.determine_num_fused_shared_experts("BailingMoeForCausalLMNextN") + + self.model = BailingMoEModelNextN( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("model.shared_head.head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + super().load_weights(weights, is_nextn=True) + + +EntryClass = [BailingMoeForCausalLMNextN] diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 363b25f46..88fe9d6fe 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -754,7 +754,12 @@ class ServerArgs: ) model_arch = self.get_hf_config().architectures[0] - if model_arch in ["DeepseekV3ForCausalLM", "Glm4MoeForCausalLM"]: + if model_arch in [ + "DeepseekV3ForCausalLM", + "Glm4MoeForCausalLM", + "BailingMoeForCausalLM", + "BailingMoeV2ForCausalLM", + ]: # Auto set draft_model_path DeepSeek-V3/R1 if self.speculative_draft_model_path is None: self.speculative_draft_model_path = self.model_path @@ -2724,6 +2729,8 @@ def auto_choose_speculative_params(self: ServerArgs): "DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM", "GptOssForCausalLM", + "BailingMoeForCausalLM", + "BailingMoeV2ForCausalLM", ]: # The default value for deepseek and gpt-oss return (3, 1, 4)
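
Below is a small illustrative sketch, separate from the patch above and not SGLang code, of the offset arithmetic behind the new QKVParallelLinear._load_qkv_block_scale helper: with block-wise FP8 quantization there is one scale value per block_n output rows, so the fused q/k/v scale tensor has to be split at offsets measured in scale rows (heads * head_size // block_n) rather than in weight rows, and the per-rank shard offset and size are likewise divided by block_n before load_qkv_weight is called. The head counts, head size, and hidden size used here are made-up example values.

```python
# Illustrative sketch only -- not SGLang code. Shows how a fused QKV
# block-quant scale tensor is sliced into q/k/v shards when each scale
# value covers block_n output rows (e.g. 128x128 FP8 weight blocks).
import torch


def qkv_block_scale_shards(total_num_heads: int, total_num_kv_heads: int,
                           head_size: int, block_n: int):
    """Return (shard_id, offset, size) triples measured in scale rows."""
    q_size = total_num_heads * head_size // block_n
    k_size = total_num_kv_heads * head_size // block_n
    v_size = total_num_kv_heads * head_size // block_n
    return [("q", 0, q_size), ("k", q_size, k_size), ("v", q_size + k_size, v_size)]


# Made-up example: 32 query heads and 4 KV heads of size 128, hidden size 4096,
# quantized in 128x128 blocks, so the fused scale tensor has
# (32 + 4 + 4) * 128 // 128 = 40 rows and 4096 // 128 = 32 columns.
block_n = block_k = 128
fused_scales = torch.randn((32 + 4 + 4) * 128 // block_n, 4096 // block_k)
for shard_id, offset, size in qkv_block_scale_shards(32, 4, 128, block_n):
    shard = fused_scales.narrow(0, offset, size)  # same narrow() call as the loader
    print(shard_id, tuple(shard.shape))  # q (32, 32), k (4, 32), v (4, 32)
```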