[Fix] fix the issue encountered when running inference with LongCat-Flash/MTP EP MoE on B200 (#9916)

This commit is contained in:
Guoyuan Lin
2025-09-02 18:11:14 +08:00
committed by GitHub
parent a96c5b5c14
commit b7361cc444
2 changed files with 49 additions and 30 deletions

View File

@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
if w.dtype in (
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
self.config.hidden_size / self.config.kv_lora_rank
) ** 0.5
# TODO(linguoyuan): EPMoE does not support DEEPGEMM_BLACKWELL; DeepEP needs to be supported in the future
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
if (
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
for layer_id in range(self.config.num_hidden_layers):
layer = self.model.layers[layer_id]
for i in range(2):
for module in [
layer.self_attn[i].fused_qkv_a_proj_with_mqa,
layer.self_attn[i].q_b_proj,
layer.self_attn[i].kv_b_proj,
layer.self_attn[i].o_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
self_attn = layer.self_attn[i]
module_list = [
self_attn.kv_b_proj,
self_attn.o_proj,
]
if self.config.q_lora_rank is not None:
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
module_list.append(self_attn.q_b_proj)
else:
module_list.append(self_attn.kv_a_proj_with_mqa)
module_list.append(self_attn.q_proj)
for module in module_list:
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
mlp = layer.mlps[i]
assert isinstance(mlp, LongcatFlashMLP)
for module in [
mlp.gate_up_proj,
mlp.down_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
for layer_id in range(self.config.num_hidden_layers):
experts = layer.mlp.experts

View File

@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
).T
else:
w = self_attn.kv_b_proj.weight
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
# This may affect the accuracy of fp8 model.
# Fix deepseek v3 blockwise bmm by using deep_gemm
use_deep_gemm_bmm = False
if w.dtype in (
torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
def _weight_requant_ue8m0(self):
weight_block_size = self.quant_config.weight_block_size
layer = self.model.decoder
for module in [
layer.self_attn.fused_qkv_a_proj_with_mqa,
layer.self_attn.q_b_proj,
layer.self_attn.kv_b_proj,
layer.self_attn.o_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
self_attn = layer.self_attn
module_list = [
self_attn.kv_b_proj,
self_attn.o_proj,
]
if self.config.q_lora_rank is not None:
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
module_list.append(self_attn.q_b_proj)
else:
module_list.append(self_attn.kv_a_proj_with_mqa)
module_list.append(self_attn.q_proj)
for module in module_list:
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
mlp = layer.mlps
assert isinstance(mlp, LongcatFlashMLP)
for module in [
mlp.gate_up_proj,
mlp.down_proj,
]:
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
if hasattr(module, "weight_scale_inv"):
requant_weight_ue8m0_inplace(
module.weight, module.weight_scale_inv, weight_block_size
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [