From 5e194b21437fed6fa1f4c8ef11ccd61b34bc0607 Mon Sep 17 00:00:00 2001 From: Guoyuan Lin Date: Sun, 31 Aug 2025 14:29:21 +0800 Subject: [PATCH] [Model] Support Meituan LongCat-Flash && LongCat-Flash-MTP (#9824) --- python/sglang/srt/configs/__init__.py | 2 + python/sglang/srt/configs/longcat_flash.py | 104 ++ python/sglang/srt/configs/model_config.py | 12 + python/sglang/srt/hf_transformers_utils.py | 2 + .../sglang/srt/layers/moe/ep_moe/kernels.py | 74 ++ python/sglang/srt/layers/moe/topk.py | 33 +- .../sglang/srt/layers/quantization/utils.py | 13 + .../sglang/srt/model_executor/model_runner.py | 5 +- python/sglang/srt/models/longcat_flash.py | 1015 +++++++++++++++++ .../sglang/srt/models/longcat_flash_nextn.py | 691 +++++++++++ 10 files changed, 1940 insertions(+), 11 deletions(-) create mode 100644 python/sglang/srt/configs/longcat_flash.py create mode 100644 python/sglang/srt/models/longcat_flash.py create mode 100644 python/sglang/srt/models/longcat_flash_nextn.py diff --git a/python/sglang/srt/configs/__init__.py b/python/sglang/srt/configs/__init__.py index 9c3008572..24fba32b3 100644 --- a/python/sglang/srt/configs/__init__.py +++ b/python/sglang/srt/configs/__init__.py @@ -5,6 +5,7 @@ from sglang.srt.configs.exaone import ExaoneConfig from sglang.srt.configs.janus_pro import MultiModalityConfig from sglang.srt.configs.kimi_vl import KimiVLConfig from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig +from sglang.srt.configs.longcat_flash import LongcatFlashConfig from sglang.srt.configs.step3_vl import ( Step3TextConfig, Step3VisionEncoderConfig, @@ -16,6 +17,7 @@ __all__ = [ "ChatGLMConfig", "DbrxConfig", "DeepseekVL2Config", + "LongcatFlashConfig", "MultiModalityConfig", "KimiVLConfig", "MoonViTConfig", diff --git a/python/sglang/srt/configs/longcat_flash.py b/python/sglang/srt/configs/longcat_flash.py new file mode 100644 index 000000000..e6a2dfb02 --- /dev/null +++ b/python/sglang/srt/configs/longcat_flash.py @@ -0,0 +1,104 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + +class LongcatFlashConfig(PretrainedConfig): + model_type = "longcat_flash" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=131072, + hidden_size=6144, + intermediate_size=None, + ffn_hidden_size=12288, + expert_ffn_hidden_size=2048, + num_layers=28, + num_hidden_layers=None, + num_attention_heads=64, + ep_size=1, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=128, + qk_nope_head_dim=128, + v_head_dim=128, + n_routed_experts=512, + moe_topk=12, + norm_topk_prob=False, + max_position_embeddings=131072, + rms_norm_eps=1e-05, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mla_scale_q_lora=True, + mla_scale_kv_lora=True, + torch_dtype="bfloat16", + params_dtype="bfloat16", + rounter_params_dtype="float32", + router_bias=False, + topk_method=None, + routed_scaling_factor=6.0, + zero_expert_num=256, + zero_expert_type="identity", + nextn_use_scmoe=False, + num_nextn_predict_layers=1, + **kwargs, + ): + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + torch_dtype=torch_dtype, + params_dtype=params_dtype, + 
rounter_params_dtype=rounter_params_dtype, + topk_method=topk_method, + router_bias=router_bias, + nextn_use_scmoe=nextn_use_scmoe, + num_nextn_predict_layers=num_nextn_predict_layers, + **kwargs, + ) + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_hidden_layers = ( + num_hidden_layers if num_hidden_layers is not None else num_layers + ) + self.intermediate_size = ( + intermediate_size if intermediate_size is not None else ffn_hidden_size + ) + self.moe_intermediate_size = expert_ffn_hidden_size + self.num_attention_heads = num_attention_heads + self.ep_size = ep_size + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.n_routed_experts = n_routed_experts + self.moe_topk = moe_topk + self.norm_topk_prob = norm_topk_prob + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mla_scale_q_lora = mla_scale_q_lora + self.mla_scale_kv_lora = mla_scale_kv_lora + self.zero_expert_num = zero_expert_num + self.zero_expert_type = zero_expert_type + self.routed_scaling_factor = routed_scaling_factor + self.hidden_act = "silu" diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index 3b3fef5c8..8fb00972e 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -132,6 +132,13 @@ class ModelConfig: if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM": self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN" + if ( + is_draft_model + and self.hf_config.architectures[0] == "LongcatFlashForCausalLM" + ): + self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN" + self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers + if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM": self.hf_config.architectures[0] = "MiMoMTP" if ( @@ -199,6 +206,8 @@ class ModelConfig: "DeepseekV2ForCausalLM" in self.hf_config.architectures or "DeepseekV3ForCausalLM" in self.hf_config.architectures or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures + or "LongcatFlashForCausalLM" in self.hf_config.architectures + or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures ): self.head_dim = 256 self.attention_arch = AttentionArch.MLA @@ -270,6 +279,9 @@ class ModelConfig: self.num_key_value_heads = self.num_attention_heads self.hidden_size = self.hf_text_config.hidden_size self.num_hidden_layers = self.hf_text_config.num_hidden_layers + self.num_attention_layers = self.num_hidden_layers + if "LongcatFlashForCausalLM" in self.hf_config.architectures: + self.num_attention_layers = self.num_hidden_layers * 2 self.num_nextn_predict_layers = getattr( self.hf_text_config, "num_nextn_predict_layers", None ) diff --git a/python/sglang/srt/hf_transformers_utils.py b/python/sglang/srt/hf_transformers_utils.py index 0edfa92ae..2f500ae79 100644 --- a/python/sglang/srt/hf_transformers_utils.py +++ b/python/sglang/srt/hf_transformers_utils.py @@ -40,6 +40,7 @@ from sglang.srt.configs import ( DeepseekVL2Config, ExaoneConfig, KimiVLConfig, + LongcatFlashConfig, MultiModalityConfig, Step3VLConfig, ) @@ -56,6 +57,7 @@ _CONFIG_REGISTRY: 
Dict[str, Type[PretrainedConfig]] = { KimiVLConfig.model_type: KimiVLConfig, InternVLChatConfig.model_type: InternVLChatConfig, Step3VLConfig.model_type: Step3VLConfig, + LongcatFlashConfig.model_type: LongcatFlashConfig, } for name, cls in _CONFIG_REGISTRY.items(): diff --git a/python/sglang/srt/layers/moe/ep_moe/kernels.py b/python/sglang/srt/layers/moe/ep_moe/kernels.py index f1649d5c9..bea38cc41 100644 --- a/python/sglang/srt/layers/moe/ep_moe/kernels.py +++ b/python/sglang/srt/layers/moe/ep_moe/kernels.py @@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess( gateup_input, gateup_input_scale, ) + + +@triton.jit +def compute_identity_kernel( + top_k, + hidden_states_ptr, + expert_scales_ptr, + num_tokens, + output_ptr, + hidden_dim, + scales_stride, + BLOCK_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + + batch_id = pid // (hidden_dim // BLOCK_SIZE) + dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE + + if batch_id >= num_tokens or dim_offset >= hidden_dim: + return + + h = tl.load( + hidden_states_ptr + + batch_id * hidden_dim + + dim_offset + + tl.arange(0, BLOCK_SIZE), + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) + + result = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for i in range(top_k): + scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i) + result += h * scale + + tl.store( + output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE), + result, + mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim, + ) + + +def zero_experts_compute_triton( + expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states +): + N = expert_indices.numel() + top_k = expert_indices.size(-1) + grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),) + + if zero_expert_type == "identity": + zero_expert_mask = expert_indices < num_experts + zero_expert_scales = expert_scales.clone() + zero_expert_scales[zero_expert_mask] = 0.0 + + normal_expert_mask = expert_indices >= num_experts + expert_indices[normal_expert_mask] = 0 + expert_scales[normal_expert_mask] = 0.0 + + output = torch.zeros_like(hidden_states).to(hidden_states.device) + hidden_dim = hidden_states.size(-1) + num_tokens = hidden_states.size(0) + + grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),) + compute_identity_kernel[grid]( + top_k, + hidden_states, + zero_expert_scales, + num_tokens, + output, + hidden_dim, + zero_expert_scales.stride(0), + BLOCK_SIZE=256, + ) + + return output diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 7e43a5541..a0cea08d6 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -357,17 +357,28 @@ def fused_topk_torch_native( gating_output: torch.Tensor, topk: int, renormalize: bool, + correction_bias: torch.Tensor = None, ): - assert ( - hidden_states.shape[0] == gating_output.shape[0] - ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" - M, _ = hidden_states.shape - topk_weights = torch.empty( - M, topk, dtype=torch.float32, device=hidden_states.device - ) - topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) - topk_weights = F.softmax(gating_output.float(), dim=-1) - topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if correction_bias is not None: + n_routed_experts = gating_output.shape[-1] + scores = gating_output.softmax(dim=-1) + scores_for_choice = scores.view( + -1, n_routed_experts + ) + correction_bias.unsqueeze(0) + topk_ids = 
torch.topk(scores_for_choice, k=topk, dim=-1, sorted=False)[1] + topk_weights = scores.gather(1, topk_ids) + else: + assert ( + hidden_states.shape[0] == gating_output.shape[0] + ), f"Number of tokens mismatch, {hidden_states.shape=} vs {gating_output.shape=}" + M, _ = hidden_states.shape + topk_weights = torch.empty( + M, topk, dtype=torch.float32, device=hidden_states.device + ) + topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device) + topk_weights = F.softmax(gating_output.float(), dim=-1) + topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1) + if renormalize: topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) return topk_weights, topk_ids @@ -380,6 +391,7 @@ def fused_topk_cpu( renormalize: bool, num_token_non_padded: Optional[torch.Tensor] = None, expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None, + correction_bias: torch.Tensor = None, ): topk_weights, topk_ids = torch.ops.sgl_kernel.topk_softmax_cpu( hidden_states=hidden_states, @@ -825,6 +837,7 @@ def select_experts( gating_output=router_logits, topk=top_k, renormalize=renormalize, + correction_bias=correction_bias, ) elif custom_routing_function is None: assert not apply_routed_scaling_factor_on_output, "Not implemented" diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py index df434ae0a..63b8b6eb7 100644 --- a/python/sglang/srt/layers/quantization/utils.py +++ b/python/sglang/srt/layers/quantization/utils.py @@ -77,6 +77,19 @@ def is_layer_skipped( ) else: is_skipped = prefix in ignored_layers + if "gate_up_proj" in prefix: + prefix_gate = prefix.replace("gate_up_proj", "gate_proj") + prefix_up = prefix.replace("gate_up_proj", "up_proj") + if prefix_gate in ignored_layers and prefix_up in ignored_layers: + is_skipped = True + elif "experts" in prefix: + is_skipped = any( + [ + prefix in layer_name + for layer_name in ignored_layers + if "experts" in layer_name + ] + ) assert is_skipped is not None return is_skipped diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index bbb0a3674..64bb885a6 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -307,7 +307,10 @@ class ModelRunner: model_num_layers = ( self.model_config.num_nextn_predict_layers if self.is_draft_worker and model_has_mtp_layers - else self.model_config.num_hidden_layers + else max( + self.model_config.num_hidden_layers, + self.model_config.num_attention_layers, + ) ) self.start_layer = getattr(self.model, "start_layer", 0) self.end_layer = getattr(self.model, "end_layer", model_num_layers) diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py new file mode 100644 index 000000000..77cf718a9 --- /dev/null +++ b/python/sglang/srt/models/longcat_flash.py @@ -0,0 +1,1015 @@ +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import concurrent.futures +import logging +import os +from enum import IntEnum, auto +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm + +from sglang.srt.configs import LongcatFlashConfig +from sglang.srt.distributed import ( + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation +from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.amx_utils import PackWeightMethod +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ( + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton +from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class +from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE +from sglang.srt.layers.moe.topk import StandardTopKOutput, TopK +from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + normalize_e4m3fn_to_e4m3fnuz, + requant_weight_ue8m0_inplace, +) +from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, +) +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA +from sglang.srt.utils import ( + BumpAllocator, + LazyValue, + 
add_prefix, + bind_or_assign, + cpu_has_amx_support, + get_bool_env_var, + get_device_sm, + get_int_env_var, + is_cpu, + is_cuda, + is_flashinfer_available, + is_hip, + is_non_idle_and_non_empty, + is_npu, + is_sm100_supported, +) + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_npu = is_npu() +_is_fp8_fnuz = is_fp8_fnuz() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_device_sm = get_device_sm() + +if _is_cuda: + from sgl_kernel import ( + awq_dequantize, + bmm_fp8, + dsv3_fused_a_gemm, + dsv3_router_gemm, + merge_state_v2, + ) +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) +else: + from vllm._custom_ops import awq_dequantize + +logger = logging.getLogger(__name__) + + +class LongcatFlashMLP(nn.Module): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + reduce_results: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, + [intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=add_prefix("gate_up_proj", prefix), + ) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config, + reduce_results=reduce_results, + prefix=add_prefix("down_proj", prefix), + ) + if hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." + ) + self.act_fn = SiluAndMul() + + def forward( + self, + x, + ): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LongcatFlashRouter(nn.Module): + def __init__( + self, + config, + zero_expert_num=0, + rounter_params_dtype=torch.float32, + prefix: str = "", + ): + super().__init__() + self.n_routed_experts = config.n_routed_experts + self.n_routed_experts = self.n_routed_experts + zero_expert_num + self.rounter_params_dtype = rounter_params_dtype + self.classifier = ReplicatedLinear( + config.hidden_size, + self.n_routed_experts, + bias=config.router_bias, + params_dtype=rounter_params_dtype, + quant_config=None, + prefix=add_prefix("classifier", prefix), + ) + self.e_score_correction_bias = nn.Parameter( + torch.zeros((self.n_routed_experts), dtype=rounter_params_dtype) + ) + + def forward(self, hidden_states): + logits, _ = self.classifier(hidden_states.to(self.rounter_params_dtype)) + return logits + + +class LongcatFlashMoE(nn.Module): + + def __init__( + self, + config: LongcatFlashConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.config = config + self.layer_id = layer_id + self.routed_scaling_factor = config.routed_scaling_factor + self.num_experts = config.n_routed_experts + self.top_k = config.moe_topk + self.zero_expert_num = config.zero_expert_num + self.zero_expert_type = config.zero_expert_type + + if config.rounter_params_dtype == "float32": + self.rounter_params_dtype = torch.float32 + else: + self.rounter_params_dtype = torch.bfloat16 + + self.tp_size = get_tensor_model_parallel_world_size() + + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}." 
+ ) + + if config.hidden_act != "silu": + raise ValueError( + f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now." + ) + + self.router = LongcatFlashRouter( + config=self.config, + zero_expert_num=self.zero_expert_num, + rounter_params_dtype=self.rounter_params_dtype, + prefix=add_prefix("router", prefix), + ) + + self.topk = TopK( + top_k=self.top_k, + renormalize=False, + use_grouped_topk=False, + correction_bias=self.router.e_score_correction_bias.data, + ) + self.topk.forward = self.topk.forward_native + + self.experts = get_moe_impl_class()( + num_experts=self.num_experts, + top_k=self.top_k, + layer_id=self.layer_id, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + prefix=add_prefix("experts", prefix), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + num_tokens, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (num_tokens, n_experts) + router_logits = self.router(hidden_states) + topk_weights, topk_idx, _ = self.topk( + hidden_states, + router_logits, + ) + if self.zero_expert_type is not None: + zero_expert_result = zero_experts_compute_triton( + expert_indices=topk_idx, + expert_scales=topk_weights, + num_experts=self.num_experts, + zero_expert_type=self.zero_expert_type, + hidden_states=hidden_states, + ) + topk_output = StandardTopKOutput(topk_weights, topk_idx, _) + + final_hidden_states = self.experts(hidden_states, topk_output) + final_hidden_states *= self.routed_scaling_factor + + if self.zero_expert_type is not None and hidden_states.shape[0] > 0: + final_hidden_states += zero_expert_result.to(final_hidden_states.device) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states) + + return final_hidden_states.view(num_tokens, hidden_dim) + + def get_moe_weights(self): + return [ + x.data + for name, x in self.experts.named_parameters() + if name not in ["correction_bias"] + ] + + +class LongcatFlashDecoderLayer(nn.Module): + + def __init__( + self, + config: LongcatFlashConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.alt_stream = alt_stream + self.self_attn = nn.ModuleList( + [ + DeepseekV2AttentionMLA( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + rope_theta=config.rope_theta, + rope_scaling=None, + max_position_embeddings=config.max_position_embeddings, + quant_config=( + None + if "self_attn" in getattr(config, "disable_quant_module", []) + else quant_config + ), + layer_id=layer_id * 2 + i, + reduce_results=False, + prefix=add_prefix(f"self_attn.{i}", prefix), + alt_stream=self.alt_stream, + ) + for i in range(2) + ] + ) + + self.input_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) + self.post_attention_layernorm = nn.ModuleList( + [RMSNorm(config.hidden_size, eps=config.rms_norm_eps) for i in range(2)] + ) + + self.mlps = nn.ModuleList( + [ + LongcatFlashMLP( + hidden_size=config.hidden_size, + 
intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=( + None + if "mlps" in getattr(config, "disable_quant_module", []) + else quant_config + ), + prefix=add_prefix(f"mlps.{i}", prefix), + ) + for i in range(2) + ] + ) + + self.mlp = LongcatFlashMoE( + layer_id=self.layer_id, + config=config, + quant_config=quant_config, + prefix=add_prefix("mlp", prefix), + ) + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + + self.mlp_layer_scatter_modes = [ + LayerScatterModes.init_new( + layer_id=self.layer_id * 2 + i, + num_layers=config.num_hidden_layers, + is_layer_sparse=False, + is_previous_layer_sparse=False, + ) + for i in range(2) + ] + self.mlp_layer_communicator = [ + LayerCommunicator( + layer_scatter_modes=self.mlp_layer_scatter_modes[i], + input_layernorm=self.input_layernorm[i], + post_attention_layernorm=self.post_attention_layernorm[i], + ) + for i in range(2) + ] + + self.moe_layer_scatter_modes = LayerScatterModes.init_new( + layer_id=self.layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=True, + is_previous_layer_sparse=True, + ) + self.moe_layer_communicator = LayerCommunicator( + layer_scatter_modes=self.moe_layer_scatter_modes, + input_layernorm=self.input_layernorm[0], + post_attention_layernorm=self.post_attention_layernorm[0], + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + zero_allocator: BumpAllocator, + ) -> torch.Tensor: + # first_attn + hidden_states, residual = self.moe_layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn[0]( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + zero_allocator=zero_allocator, + ) + + # moe + hidden_states, residual = self.moe_layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + moe_hidden_states = hidden_states.clone() + moe_residual = residual.clone() + moe_hidden_states = self.mlp(moe_hidden_states) + moe_hidden_states, moe_residual = self.moe_layer_communicator.postprocess_layer( + moe_hidden_states, moe_residual, forward_batch + ) + + hidden_states, residual = self.forward_mlp( + hidden_states, positions, residual, forward_batch, zero_allocator + ) + + hidden_states = moe_hidden_states + hidden_states + return hidden_states, residual + + def forward_mlp( + self, hidden_states, positions, residual, forward_batch, zero_allocator + ): + # first_mlp + hidden_states = self.mlps[0](hidden_states) + # TP all_reduce + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + # second_attn + hidden_states, residual = self.mlp_layer_communicator[1].prepare_attn( + hidden_states, residual, forward_batch + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn[1]( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + zero_allocator=zero_allocator, + ) + + # second_mlp + hidden_states, residual = self.mlp_layer_communicator[1].prepare_mlp( + hidden_states, residual, forward_batch + ) + hidden_states = self.mlps[1](hidden_states) + # TP all_reduce + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + hidden_states, residual = self.mlp_layer_communicator[1].postprocess_layer( + hidden_states, residual, forward_batch + ) + + return hidden_states, residual + + +class LongcatFlashModel(nn.Module): + fall_back_to_pt_during_load 
= False + + def __init__( + self, + config: LongcatFlashConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + ) + + self.alt_stream = torch.cuda.Stream() + self.layers = nn.ModuleList( + [ + LongcatFlashDecoderLayer( + config, + layer_id, + quant_config=quant_config, + prefix=add_prefix(f"layers.{layer_id}", prefix), + alt_stream=self.alt_stream, + ) + for layer_id in range(config.num_hidden_layers) + ] + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self) -> torch.Tensor: + return self.embed_tokens + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + total_num_layers = len(self.layers) + device = input_embeds.device if input_embeds is not None else input_ids.device + zero_allocator = BumpAllocator( + buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), + dtype=torch.float32, + device=device, + ) + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + residual = None + + for i in range(total_num_layers): + with get_global_expert_distribution_recorder().with_current_layer(i): + layer = self.layers[i] + hidden_states, residual = layer( + positions, hidden_states, forward_batch, residual, zero_allocator + ) + + if hidden_states.shape[0] != 0: + if residual is None: + hidden_states = self.norm(hidden_states) + else: + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LongcatFlashForCausalLM(nn.Module): + # for quark model load + packed_modules_mapping = {} + + def __init__( + self, + config: LongcatFlashConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + + # for quark model load + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + self.fuse_qkv_a_proj = ( + hasattr(config, "q_lora_rank") and config.q_lora_rank is not None + ) + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.config = config + self.tp_size = get_tensor_model_parallel_world_size() + self.quant_config = quant_config + self.model = LongcatFlashModel( + config, quant_config, prefix=add_prefix("model", prefix) + ) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("lm_head", prefix), + use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"], + ) + self.logits_processor = LogitsProcessor(config) + + def get_input_embeddings(self) -> nn.Embedding: + return self.model.embed_tokens + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch, input_embeds) + + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def post_load_weights(self, weight_names=None): + + # Perform post-processing after loading weights + if weight_names is None: + layer_ids = range(self.config.num_hidden_layers) + else: + layer_ids = set() + for 
name in weight_names: + if "kv_b_proj" in name: + layer_id = int(name.split(".")[2]) + if layer_id < self.config.num_hidden_layers: + layer_ids.add(layer_id) + + for layer_id in layer_ids: + for i in range(2): + self_attn = self.model.layers[layer_id].self_attn[i] + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda or _is_hip: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. + # Fix deepseek v3 blockwise bmm by using deep_gemm + use_deep_gemm_bmm = False + + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + if ( + hasattr(self.quant_config, "weight_block_size") + and self.quant_config.weight_block_size is not None + ): + weight_block_size = self.quant_config.weight_block_size + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + + if ( + _is_cuda + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + ): + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + torch.bfloat16, + ) + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + if not use_deep_gemm_bmm: + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, + w_kc.transpose(1, 2).contiguous().transpose(1, 2), + ) + self_attn.w_vc = bind_or_assign( + self_attn.w_vc, w_vc.contiguous().transpose(1, 2) + ) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = bind_or_assign( + self_attn.w_scale, self_attn.kv_b_proj.weight_scale + ) + if _is_hip: + self_attn.w_scale *= 2.0 + # TODO: remove this after adding FP8 support in bmm cpu kernel + if ( + 
_is_cpu + and _is_cpu_amx_available + and w.dtype == torch.float8_e4m3fn + ): + self_attn.w_kc = ( + self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale + ) + self_attn.w_vc = ( + self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale + ) + else: + num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self_attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + self_attn.w_scale_k = bind_or_assign( + self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() + ) + self_attn.w_scale_v = bind_or_assign( + self_attn.w_scale_v, ws_vc.contiguous() + ) + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous() + ) + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) + self_attn.use_deep_gemm_bmm = True + + if self.config.mla_scale_q_lora: + self_attn.q_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.q_lora_rank + ) ** 0.5 + if self.config.mla_scale_kv_lora: + self_attn.kv_a_layernorm.weight.data *= ( + self.config.hidden_size / self.config.kv_lora_rank + ) ** 0.5 + + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 + and hasattr(self.quant_config, "weight_block_size") + and self.quant_config.weight_block_size is not None + ): + self._weight_requant_ue8m0() + + def _weight_requant_ue8m0(self): + weight_block_size = self.quant_config.weight_block_size + + for layer_id in range(self.config.num_hidden_layers): + layer = self.model.layers[layer_id] + for i in range(2): + for module in [ + layer.self_attn[i].fused_qkv_a_proj_with_mqa, + layer.self_attn[i].q_b_proj, + layer.self_attn[i].kv_b_proj, + layer.self_attn[i].o_proj, + ]: + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) + mlp = layer.mlps[i] + assert isinstance(mlp, LongcatFlashMLP) + for module in [ + mlp.gate_up_proj, + mlp.down_proj, + ]: + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) + + for layer_id in range(self.config.num_hidden_layers): + experts = layer.mlp.experts + if isinstance(experts, DeepEPMoE): + for w in [ + experts.w13_weight_fp8, + experts.w2_weight_fp8, + ]: + requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts, + ) + + # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None + fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and ( + self.config.q_lora_rank is not None + ) + cached_a_proj = {} if fuse_qkv_a_proj else None + + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + params_dict = dict(self.named_parameters()) + weight_names = [] + for name, loaded_weight in weights: + if "mtp" in name: + continue + weight_names.append(name) + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and 
experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + futures.append( + executor.submit(weight_loader, param, loaded_weight, shard_id) + ) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + futures.append( + executor.submit( + weight_loader, + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if fuse_qkv_a_proj and ( + "q_a_proj" in name or "kv_a_proj_with_mqa" in name + ): + cached_a_proj[name] = loaded_weight + q_a_proj_name = ( + name + if "q_a_proj" in name + else name.replace("kv_a_proj_with_mqa", "q_a_proj") + ) + kv_a_proj_name = ( + name + if "kv_a_proj_with_mqa" in name + else name.replace("q_a_proj", "kv_a_proj_with_mqa") + ) + + # When both q_a_proj and kv_a_proj_with_mqa has been cached, load the fused weight to parameter + if ( + q_a_proj_name in cached_a_proj + and kv_a_proj_name in cached_a_proj + ): + q_a_proj_weight = cached_a_proj[q_a_proj_name] + kv_a_proj_weight = cached_a_proj[kv_a_proj_name] + cat_dim = 0 + if self.quant_config is not None and ( + self.quant_config.get_name() == "awq" + or self.quant_config.get_name() == "awq_marlin" + or self.quant_config.get_name() == "moe_wna16" + ): + cat_dim = 1 + fused_weight = torch.cat( + [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim + ) + param_name = ( + name.replace( + "q_a_proj", "fused_qkv_a_proj_with_mqa" + ) + if "q_a_proj" in name + else name.replace( + "kv_a_proj_with_mqa", + "fused_qkv_a_proj_with_mqa", + ) + ) + param = params_dict[param_name] + + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + futures.append( + executor.submit(weight_loader, param, fused_weight) + ) + cached_a_proj.pop(q_a_proj_name) + cached_a_proj.pop(kv_a_proj_name) + else: + if ( + "k_scale" in name or "v_scale" in name + ) and name not in params_dict: + # modelopt attn kv scale is named differently + for scale in ["k_scale", "v_scale"]: + if scale in name: + name = name.replace( + f"{scale[0]}_proj", "attn_mqa" + ) + break + if name not in params_dict: + # modelopt ckpt contains not needed weights for MTP module: + # model.decoder.self_attn.attn_mqa.v_scale and + # model.decoder.self_attn.attn_mqa.k_scale + logger.warning(f"{name} not found in params_dict.") + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + futures.append( + executor.submit(weight_loader, param, loaded_weight) + ) + + # Wait for all tasks to complete and raise any exceptions. 
+ for future in concurrent.futures.as_completed(futures): + future.result() + + self.post_load_weights(weight_names=weight_names) + + def get_embed_and_head(self): + return self.model.embed_tokens.weight, self.lm_head.weight + + def set_embed_and_head(self, embed, head): + del self.model.embed_tokens.weight + del self.lm_head.weight + self.model.embed_tokens.weight = embed + self.lm_head.weight = head + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @classmethod + def get_model_config_for_expert_location(cls, config): + return ModelConfigForExpertLocation( + num_layers=config.num_hidden_layers, + num_logical_experts=config.n_routed_experts, + ) + + +EntryClass = [LongcatFlashForCausalLM] diff --git a/python/sglang/srt/models/longcat_flash_nextn.py b/python/sglang/srt/models/longcat_flash_nextn.py new file mode 100644 index 000000000..dfd455456 --- /dev/null +++ b/python/sglang/srt/models/longcat_flash_nextn.py @@ -0,0 +1,691 @@ +# Apache License, Version 2.0: +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# MIT License: +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
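+# This module defines the LongCat-Flash MTP ("NextN") draft model used for
+# speculative decoding: a single LongcatFlashDenseDecoderLayer driven by
+# enorm/hnorm and an eh_proj over the concatenation of the token embedding and
+# the previous hidden state, reusing LongcatFlashMLP from longcat_flash.py and
+# mirroring its kv_b_proj post-load handling.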
+ +import concurrent.futures +import logging +import os +from enum import IntEnum, auto +from typing import Any, Dict, Iterable, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm + +from sglang.srt.configs import LongcatFlashConfig +from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder +from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes +from sglang.srt.layers.dp_attention import ( + get_attention_tp_rank, + get_attention_tp_size, + is_dp_attention_enabled, +) +from sglang.srt.layers.layernorm import RMSNorm +from sglang.srt.layers.linear import ReplicatedLinear +from sglang.srt.layers.logits_processor import LogitsProcessor +from sglang.srt.layers.quantization import deep_gemm_wrapper +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz +from sglang.srt.layers.quantization.fp8_utils import ( + block_quant_dequant, + block_quant_to_tensor_quant, + channel_quant_to_tensor_quant, + normalize_e4m3fn_to_e4m3fnuz, + requant_weight_ue8m0_inplace, +) +from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, +) +from sglang.srt.layers.vocab_parallel_embedding import ( + ParallelLMHead, + VocabParallelEmbedding, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA +from sglang.srt.models.longcat_flash import LongcatFlashForCausalLM, LongcatFlashMLP +from sglang.srt.utils import ( + BumpAllocator, + LazyValue, + add_prefix, + bind_or_assign, + cpu_has_amx_support, + get_bool_env_var, + get_device_sm, + is_cpu, + is_cuda, + is_hip, + is_npu, +) + +_is_hip = is_hip() +_is_cuda = is_cuda() +_is_npu = is_npu() +_is_fp8_fnuz = is_fp8_fnuz() +_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip +_is_cpu_amx_available = cpu_has_amx_support() +_is_cpu = is_cpu() +_device_sm = get_device_sm() + +if _is_cuda: + from sgl_kernel import ( + awq_dequantize, + bmm_fp8, + dsv3_fused_a_gemm, + dsv3_router_gemm, + merge_state_v2, + ) +elif _is_cpu and _is_cpu_amx_available: + pass +elif _is_hip: + from sglang.srt.layers.quantization.awq_triton import ( + awq_dequantize_triton as awq_dequantize, + ) +else: + from vllm._custom_ops import awq_dequantize + + +logger = logging.getLogger(__name__) + + +class LongcatFlashDenseDecoderLayer(nn.Module): + + def __init__( + self, + config: LongcatFlashConfig, + layer_id: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + alt_stream: Optional[torch.cuda.Stream] = None, + ) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.layer_id = layer_id + self.alt_stream = alt_stream + + self.self_attn = DeepseekV2AttentionMLA( + config=config, + hidden_size=config.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank, + kv_lora_rank=config.kv_lora_rank, + rope_theta=config.rope_theta, + rope_scaling=None, + max_position_embeddings=config.max_position_embeddings, + quant_config=quant_config, + layer_id=layer_id, + reduce_results=False, + prefix=add_prefix(f"self_attn", prefix), + alt_stream=self.alt_stream, + ) + + self.mlp = LongcatFlashMLP( + 
hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=add_prefix(f"mlps", prefix), + ) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + self.layer_scatter_modes = LayerScatterModes.init_new( + layer_id=self.layer_id, + num_layers=config.num_hidden_layers, + is_layer_sparse=False, + is_previous_layer_sparse=False, + ) + self.layer_communicator = LayerCommunicator( + layer_scatter_modes=self.layer_scatter_modes, + input_layernorm=self.input_layernorm, + post_attention_layernorm=self.post_attention_layernorm, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: ForwardBatch, + residual: Optional[torch.Tensor], + zero_allocator: BumpAllocator, + ) -> torch.Tensor: + + hidden_states, residual = self.layer_communicator.prepare_attn( + hidden_states, residual, forward_batch + ) + if hidden_states.shape[0] != 0: + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + forward_batch=forward_batch, + zero_allocator=zero_allocator, + ) + + hidden_states, residual = self.layer_communicator.prepare_mlp( + hidden_states, residual, forward_batch + ) + hidden_states = self.mlp(hidden_states) + hidden_states, residual = self.layer_communicator.postprocess_layer( + hidden_states, residual, forward_batch + ) + return hidden_states, residual + + +class LongcatFlashModelNextN(nn.Module): + def __init__( + self, + config: LongcatFlashConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.vocab_size = config.vocab_size + self.alt_stream = torch.cuda.Stream() + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + enable_tp=not is_dp_attention_enabled(), + prefix=add_prefix("embed_tokens", prefix), + ) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.eh_proj = ReplicatedLinear( + 2 * config.hidden_size, + config.hidden_size, + bias=False, + quant_config=quant_config, + prefix=add_prefix("eh_proj", ""), + ) + self.decoder = LongcatFlashDenseDecoderLayer( + config, 0, quant_config=quant_config, alt_stream=self.alt_stream + ) + + self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self) -> torch.Tensor: + return self.embed_tokens + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + total_num_layers = 1 + device = input_embeds.device if input_embeds is not None else input_ids.device + zero_allocator = BumpAllocator( + buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1), + dtype=torch.float32, + device=device, + ) + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + if hidden_states.shape[0] > 0: + hidden_states, _ = self.eh_proj( + torch.cat( + ( + self.enorm(hidden_states), + self.hnorm(forward_batch.spec_info.hidden_states), + ), + dim=-1, + ) + ) + + residual = None + with get_global_expert_distribution_recorder().disable_this_region(): + hidden_states, residual = 
self.decoder( + positions, hidden_states, forward_batch, residual, zero_allocator + ) + + if not forward_batch.forward_mode.is_idle(): + if residual is not None: + hidden_states, _ = self.final_layernorm(hidden_states, residual) + else: + hidden_states = self.final_layernorm(hidden_states) + return hidden_states + + +class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM): + + def __init__( + self, + config: LongcatFlashConfig, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + self.config = config + self.quant_config = ( + None + if "mtp" in getattr(config, "disable_quant_module", []) + else quant_config + ) + self.model = LongcatFlashModelNextN(config, self.quant_config) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=self.quant_config, + ) + self.logits_processor = LogitsProcessor(config) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, forward_batch) + return self.logits_processor( + input_ids, hidden_states, self.lm_head, forward_batch + ) + + def post_load_weights(self): + self_attn = self.model.decoder.self_attn + if hasattr(self_attn.kv_b_proj, "qweight"): + # AWQ compatible + if _is_cuda or _is_hip: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + ).T + else: + w = awq_dequantize( + self_attn.kv_b_proj.qweight, + self_attn.kv_b_proj.scales, + self_attn.kv_b_proj.qzeros, + 0, + 0, + 0, + ).T + else: + w = self_attn.kv_b_proj.weight + # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. + # This may affect the accuracy of fp8 model. 
+ # Fix deepseek v3 blockwise bmm by using deep_gemm + use_deep_gemm_bmm = False + if w.dtype in ( + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + ): + if ( + hasattr(self.quant_config, "weight_block_size") + and self.quant_config.weight_block_size is not None + ): + weight_block_size = self.quant_config.weight_block_size + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale_inv, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + if ( + _is_cuda + and weight_block_size[0] == 128 + and weight_block_size[1] == 128 + ): + if ( + deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM + and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL + and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false") + ): + block_scale = weight_scale + use_deep_gemm_bmm = True + else: + w = block_quant_dequant( + weight, + weight_scale, + weight_block_size, + torch.bfloat16, + ) + else: + w, scale = block_quant_to_tensor_quant( + weight, weight_scale, weight_block_size + ) + self_attn.w_scale = scale + else: + if _is_fp8_fnuz: + weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz( + weight=w, + weight_scale=self_attn.kv_b_proj.weight_scale, + input_scale=None, + ) + else: + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale + w, scale = channel_quant_to_tensor_quant(weight, weight_scale) + self_attn.w_scale = scale + if w.dtype == torch.int8: + if hasattr(self.quant_config, "weight_block_size"): + # block-wise int8 need it + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant(weight, weight_scale, weight_block_size).to( + torch.bfloat16 + ) + else: + # channel-wise int8 need it + w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to( + torch.bfloat16 + ) + w_kc, w_vc = w.unflatten( + 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) + ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) + if not use_deep_gemm_bmm: + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2) + ) + self_attn.w_vc = bind_or_assign( + self_attn.w_vc, w_vc.contiguous().transpose(1, 2) + ) + if ( + hasattr(self_attn.kv_b_proj, "weight_scale") + and self_attn.w_scale is None + ): + self_attn.w_scale = bind_or_assign( + self_attn.w_scale, self_attn.kv_b_proj.weight_scale + ) + if _is_hip: + self_attn.w_scale *= 2.0 + # TODO: remove this after adding FP8 support in bmm cpu kernel + if _is_cpu and _is_cpu_amx_available and w.dtype == torch.float8_e4m3fn: + self_attn.w_kc = self_attn.w_kc.to(torch.bfloat16) * self_attn.w_scale + self_attn.w_vc = self_attn.w_vc.to(torch.bfloat16) * self_attn.w_scale + else: + num_tiles_k = self_attn.qk_nope_head_dim // weight_block_size[1] + num_tiles_n = self_attn.v_head_dim // weight_block_size[0] + ws_kc, ws_vc = block_scale.unflatten( + 0, (-1, (num_tiles_k + num_tiles_n)) + ).split([num_tiles_k, num_tiles_n], dim=1) + self_attn.w_scale_k = bind_or_assign( + self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous() + ) + self_attn.w_scale_v = bind_or_assign( + self_attn.w_scale_v, ws_vc.contiguous() + ) + self_attn.w_kc = bind_or_assign( + self_attn.w_kc, w_kc.transpose(1, 2).contiguous() + ) + self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous()) + 
+            self_attn.use_deep_gemm_bmm = True
+
+        if self.config.mla_scale_q_lora:
+            self_attn.q_a_layernorm.weight.data *= (
+                self.config.hidden_size / self.config.q_lora_rank
+            ) ** 0.5
+        if self.config.mla_scale_kv_lora:
+            self_attn.kv_a_layernorm.weight.data *= (
+                self.config.hidden_size / self.config.kv_lora_rank
+            ) ** 0.5
+
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
+        ):
+            self._weight_requant_ue8m0()
+
+    def _weight_requant_ue8m0(self):
+        weight_block_size = self.quant_config.weight_block_size
+        layer = self.model.decoder
+        for module in [
+            layer.self_attn.fused_qkv_a_proj_with_mqa,
+            layer.self_attn.q_b_proj,
+            layer.self_attn.kv_b_proj,
+            layer.self_attn.o_proj,
+        ]:
+            requant_weight_ue8m0_inplace(
+                module.weight, module.weight_scale_inv, weight_block_size
+            )
+        mlp = layer.mlps
+        assert isinstance(mlp, LongcatFlashMLP)
+        for module in [
+            mlp.gate_up_proj,
+            mlp.down_proj,
+        ]:
+            requant_weight_ue8m0_inplace(
+                module.weight, module.weight_scale_inv, weight_block_size
+            )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
+        fuse_qkv_a_proj = hasattr(self.config, "q_lora_rank") and (
+            self.config.q_lora_rank is not None
+        )
+        cached_a_proj = {} if fuse_qkv_a_proj else None
+
+        nextn_layer_prefix = "model.layers.0"
+        nextn_spec_weight_names = [
+            "shared_head.norm",
+            "eh_proj",
+            "enorm",
+            "hnorm",
+            "final_layernorm",
+        ]
+
+        weight_names_mapping = {
+            "model.mtp.embed_tokens.weight": "embed_tokens.weight",
+            "model.mtp.layers.0.eh_proj.weight": "eh_proj.weight",
+            "model.mtp.layers.0.eh_proj.weight_scale_inv": "eh_proj.weight_scale_inv",
+            "model.mtp.layers.0.enorm.m.weight": "enorm.weight",
+            "model.mtp.layers.0.hnorm.m.weight": "hnorm.weight",
+            "model.mtp.layers.0.input_layernorm.weight": "layers.0.input_layernorm.weight",
+            "model.mtp.layers.0.post_attention_layernorm.weight": "layers.0.post_attention_layernorm.weight",
+            "model.mtp.layers.0.self_attn.kv_a_layernorm.weight": "layers.0.self_attn.kv_a_layernorm.weight",
+            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight": "layers.0.self_attn.kv_a_proj_with_mqa.weight",
+            "model.mtp.layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv": "layers.0.self_attn.kv_a_proj_with_mqa.weight_scale_inv",
+            "model.mtp.layers.0.self_attn.kv_b_proj.weight": "layers.0.self_attn.kv_b_proj.weight",
+            "model.mtp.layers.0.self_attn.kv_b_proj.weight_scale_inv": "layers.0.self_attn.kv_b_proj.weight_scale_inv",
+            "model.mtp.layers.0.self_attn.o_proj.weight": "layers.0.self_attn.o_proj.weight",
+            "model.mtp.layers.0.self_attn.o_proj.weight_scale_inv": "layers.0.self_attn.o_proj.weight_scale_inv",
+            "model.mtp.layers.0.self_attn.q_a_layernorm.weight": "layers.0.self_attn.q_a_layernorm.weight",
+            "model.mtp.layers.0.self_attn.q_a_proj.weight": "layers.0.self_attn.q_a_proj.weight",
+            "model.mtp.layers.0.self_attn.q_a_proj.weight_scale_inv": "layers.0.self_attn.q_a_proj.weight_scale_inv",
+            "model.mtp.layers.0.self_attn.q_b_proj.weight": "layers.0.self_attn.q_b_proj.weight",
+            "model.mtp.layers.0.self_attn.q_b_proj.weight_scale_inv": "layers.0.self_attn.q_b_proj.weight_scale_inv",
"model.mtp.layers.0.transformer_layer.mlp.down_proj.weight": "layers.0.mlp.down_proj.weight", + "model.mtp.layers.0.transformer_layer.mlp.down_proj.weight_scale_inv": "layers.0.mlp.down_proj.weight_scale_inv", + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight": "layers.0.mlp.gate_proj.weight", + "model.mtp.layers.0.transformer_layer.mlp.gate_proj.weight_scale_inv": "layers.0.mlp.gate_proj.weight_scale_inv", + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight": "layers.0.mlp.up_proj.weight", + "model.mtp.layers.0.transformer_layer.mlp.up_proj.weight_scale_inv": "layers.0.mlp.up_proj.weight_scale_inv", + "model.mtp.norm.weight": "layers.0.final_layernorm.weight", + } + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [] + params_dict = dict(self.named_parameters()) + weight_names = [] + for name, loaded_weight in weights: + if ".mtp." not in name: + continue + if name in weight_names_mapping: + name = weight_names_mapping[name] + if name.startswith("layers.0"): + name = "model." + name + if ( + name.startswith("enorm") + or name.startswith("hnorm") + or name.startswith("eh_proj") + ): + name = nextn_layer_prefix + "." + name + if not name.startswith(nextn_layer_prefix): + continue + + # Use shared head and embed weights from target model + if "shared_head.head" in name or "embed_tokens" in name: + continue + + is_decoder = True + # For nextn specific weights + for weight_name in nextn_spec_weight_names: + if weight_name in name: + name = name.replace(nextn_layer_prefix, "model") + is_decoder = False + break + # For decoder layer weights + if is_decoder: + name = name.replace(nextn_layer_prefix, "model.decoder") + + weight_names.append(name) + if "rotary_emb.inv_freq" in name: + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if ("mlp.experts." in name) and name not in params_dict: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + futures.append( + executor.submit(weight_loader, param, loaded_weight, shard_id) + ) + break + else: + # Skip loading extra bias for GPTQ models. 
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    if fuse_qkv_a_proj and (
+                        "q_a_proj" in name or "kv_a_proj_with_mqa" in name
+                    ):
+                        cached_a_proj[name] = loaded_weight
+                        q_a_proj_name = (
+                            name
+                            if "q_a_proj" in name
+                            else name.replace("kv_a_proj_with_mqa", "q_a_proj")
+                        )
+                        kv_a_proj_name = (
+                            name
+                            if "kv_a_proj_with_mqa" in name
+                            else name.replace("q_a_proj", "kv_a_proj_with_mqa")
+                        )
+
+                        # When both q_a_proj and kv_a_proj_with_mqa have been cached, load the fused weight into the parameter
+                        if (
+                            q_a_proj_name in cached_a_proj
+                            and kv_a_proj_name in cached_a_proj
+                        ):
+                            q_a_proj_weight = cached_a_proj[q_a_proj_name]
+                            kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
+                            cat_dim = 0
+                            if self.quant_config is not None and (
+                                self.quant_config.get_name() == "awq"
+                                or self.quant_config.get_name() == "awq_marlin"
+                                or self.quant_config.get_name() == "moe_wna16"
+                            ):
+                                cat_dim = 1
+                            fused_weight = torch.cat(
+                                [q_a_proj_weight, kv_a_proj_weight], dim=cat_dim
+                            )
+                            param_name = (
+                                name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                                if "q_a_proj" in name
+                                else name.replace(
+                                    "kv_a_proj_with_mqa",
+                                    "fused_qkv_a_proj_with_mqa",
+                                )
+                            )
+                            param = params_dict[param_name]
+
+                            weight_loader = getattr(
+                                param, "weight_loader", default_weight_loader
+                            )
+                            futures.append(
+                                executor.submit(weight_loader, param, fused_weight)
+                            )
+                            cached_a_proj.pop(q_a_proj_name)
+                            cached_a_proj.pop(kv_a_proj_name)
+                    else:
+                        if (
+                            "k_scale" in name or "v_scale" in name
+                        ) and name not in params_dict:
+                            # modelopt attn kv scale is named differently
+                            for scale in ["k_scale", "v_scale"]:
+                                if scale in name:
+                                    name = name.replace(f"{scale[0]}_proj", "attn_mqa")
+                                    break
+                        if name not in params_dict:
+                            # modelopt ckpt contains weights that are not needed by the MTP module:
+                            # model.decoder.self_attn.attn_mqa.v_scale and
+                            # model.decoder.self_attn.attn_mqa.k_scale
+                            logger.warning(f"{name} not found in params_dict.")
+                            continue
+                        param = params_dict[name]
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        futures.append(
+                            executor.submit(weight_loader, param, loaded_weight)
+                        )
+        self.post_load_weights()
+
+
+EntryClass = [LongcatFlashForCausalLMNextN]
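+# EntryClass is what SGLang's model loader scans to map an architecture name to its
+# implementation; only the NextN class (the single MTP draft layer used for speculative
+# decoding) is exported from this module.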