From b7361cc4441d7843d4799da4bf78c3654a39422e Mon Sep 17 00:00:00 2001 From: Guoyuan Lin Date: Tue, 2 Sep 2025 18:11:14 +0800 Subject: [PATCH] [Fix] fix the issue encountered when inference LongCat-Flash/MTP EP MoE on b200 (#9916) --- python/sglang/srt/models/longcat_flash.py | 41 ++++++++++++------- .../sglang/srt/models/longcat_flash_nextn.py | 38 ++++++++++------- 2 files changed, 49 insertions(+), 30 deletions(-) diff --git a/python/sglang/srt/models/longcat_flash.py b/python/sglang/srt/models/longcat_flash.py index 77cf718a9..9531cb83e 100644 --- a/python/sglang/srt/models/longcat_flash.py +++ b/python/sglang/srt/models/longcat_flash.py @@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module): ).T else: w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. - # Fix deepseek v3 blockwise bmm by using deep_gemm use_deep_gemm_bmm = False if w.dtype in ( @@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module): self.config.hidden_size / self.config.kv_lora_rank ) ** 0.5 + # TODO(linguoyuan): EPMoE does not support DEEPGEMM_BLACKWELL yet; DeepEP needs to be supported in the future + deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False + if ( deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 @@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module): for layer_id in range(self.config.num_hidden_layers): layer = self.model.layers[layer_id] for i in range(2): - for module in [ - layer.self_attn[i].fused_qkv_a_proj_with_mqa, - layer.self_attn[i].q_b_proj, - layer.self_attn[i].kv_b_proj, - layer.self_attn[i].o_proj, - ]: - requant_weight_ue8m0_inplace( - module.weight, module.weight_scale_inv, weight_block_size - ) + self_attn = layer.self_attn[i] + module_list = [ + self_attn.kv_b_proj, + self_attn.o_proj, + ] + + if self.config.q_lora_rank is not None: + 
module_list.append(self_attn.fused_qkv_a_proj_with_mqa) + module_list.append(self_attn.q_b_proj) + else: + module_list.append(self_attn.kv_a_proj_with_mqa) + module_list.append(self_attn.q_proj) + + for module in module_list: + if hasattr(module, "weight_scale_inv"): + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) + mlp = layer.mlps[i] assert isinstance(mlp, LongcatFlashMLP) for module in [ mlp.gate_up_proj, mlp.down_proj, ]: - requant_weight_ue8m0_inplace( - module.weight, module.weight_scale_inv, weight_block_size - ) + if hasattr(module, "weight_scale_inv"): + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) for layer_id in range(self.config.num_hidden_layers): experts = layer.mlp.experts diff --git a/python/sglang/srt/models/longcat_flash_nextn.py b/python/sglang/srt/models/longcat_flash_nextn.py index dfd455456..64a4265c5 100644 --- a/python/sglang/srt/models/longcat_flash_nextn.py +++ b/python/sglang/srt/models/longcat_flash_nextn.py @@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM): ).T else: w = self_attn.kv_b_proj.weight - # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`. - # This may affect the accuracy of fp8 model. 
- # Fix deepseek v3 blockwise bmm by using deep_gemm use_deep_gemm_bmm = False if w.dtype in ( torch.float8_e4m3fn, @@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM): def _weight_requant_ue8m0(self): weight_block_size = self.quant_config.weight_block_size layer = self.model.decoder - for module in [ - layer.self_attn.fused_qkv_a_proj_with_mqa, - layer.self_attn.q_b_proj, - layer.self_attn.kv_b_proj, - layer.self_attn.o_proj, - ]: - requant_weight_ue8m0_inplace( - module.weight, module.weight_scale_inv, weight_block_size - ) + self_attn = layer.self_attn + module_list = [ + self_attn.kv_b_proj, + self_attn.o_proj, + ] + + if self.config.q_lora_rank is not None: + module_list.append(self_attn.fused_qkv_a_proj_with_mqa) + module_list.append(self_attn.q_b_proj) + else: + module_list.append(self_attn.kv_a_proj_with_mqa) + module_list.append(self_attn.q_proj) + + for module in module_list: + if hasattr(module, "weight_scale_inv"): + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) + mlp = layer.mlps assert isinstance(mlp, LongcatFlashMLP) for module in [ mlp.gate_up_proj, mlp.down_proj, ]: - requant_weight_ue8m0_inplace( - module.weight, module.weight_scale_inv, weight_block_size - ) + if hasattr(module, "weight_scale_inv"): + requant_weight_ue8m0_inplace( + module.weight, module.weight_scale_inv, weight_block_size + ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [