[Fix] fix the issue encountered when inference LongCat-Flash/MTP EP MoE on b200 (#9916)
This commit is contained in:
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
|
||||
).T
|
||||
else:
|
||||
w = self_attn.kv_b_proj.weight
|
||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||
# This may affect the accuracy of fp8 model.
|
||||
# Fix deepseek v3 blockwise bmm by using deep_gemm
|
||||
use_deep_gemm_bmm = False
|
||||
|
||||
if w.dtype in (
|
||||
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
|
||||
self.config.hidden_size / self.config.kv_lora_rank
|
||||
) ** 0.5
|
||||
|
||||
# TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
|
||||
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
|
||||
|
||||
if (
|
||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
layer = self.model.layers[layer_id]
|
||||
for i in range(2):
|
||||
for module in [
|
||||
layer.self_attn[i].fused_qkv_a_proj_with_mqa,
|
||||
layer.self_attn[i].q_b_proj,
|
||||
layer.self_attn[i].kv_b_proj,
|
||||
layer.self_attn[i].o_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
self_attn = layer.self_attn[i]
|
||||
module_list = [
|
||||
self_attn.kv_b_proj,
|
||||
self_attn.o_proj,
|
||||
]
|
||||
|
||||
if self.config.q_lora_rank is not None:
|
||||
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
|
||||
module_list.append(self_attn.q_b_proj)
|
||||
else:
|
||||
module_list.append(self_attn.kv_a_proj_with_mqa)
|
||||
module_list.append(self_attn.q_proj)
|
||||
|
||||
for module in module_list:
|
||||
if hasattr(module, "weight_scale_inv"):
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
mlp = layer.mlps[i]
|
||||
assert isinstance(mlp, LongcatFlashMLP)
|
||||
for module in [
|
||||
mlp.gate_up_proj,
|
||||
mlp.down_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
if hasattr(module, "weight_scale_inv"):
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
for layer_id in range(self.config.num_hidden_layers):
|
||||
experts = layer.mlp.experts
|
||||
|
||||
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
|
||||
).T
|
||||
else:
|
||||
w = self_attn.kv_b_proj.weight
|
||||
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|
||||
# This may affect the accuracy of fp8 model.
|
||||
# Fix deepseek v3 blockwise bmm by using deep_gemm
|
||||
use_deep_gemm_bmm = False
|
||||
if w.dtype in (
|
||||
torch.float8_e4m3fn,
|
||||
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
|
||||
def _weight_requant_ue8m0(self):
|
||||
weight_block_size = self.quant_config.weight_block_size
|
||||
layer = self.model.decoder
|
||||
for module in [
|
||||
layer.self_attn.fused_qkv_a_proj_with_mqa,
|
||||
layer.self_attn.q_b_proj,
|
||||
layer.self_attn.kv_b_proj,
|
||||
layer.self_attn.o_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
self_attn = layer.self_attn
|
||||
module_list = [
|
||||
self_attn.kv_b_proj,
|
||||
self_attn.o_proj,
|
||||
]
|
||||
|
||||
if self.config.q_lora_rank is not None:
|
||||
module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
|
||||
module_list.append(self_attn.q_b_proj)
|
||||
else:
|
||||
module_list.append(self_attn.kv_a_proj_with_mqa)
|
||||
module_list.append(self_attn.q_proj)
|
||||
|
||||
for module in module_list:
|
||||
if hasattr(module, "weight_scale_inv"):
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
mlp = layer.mlps
|
||||
assert isinstance(mlp, LongcatFlashMLP)
|
||||
for module in [
|
||||
mlp.gate_up_proj,
|
||||
mlp.down_proj,
|
||||
]:
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
if hasattr(module, "weight_scale_inv"):
|
||||
requant_weight_ue8m0_inplace(
|
||||
module.weight, module.weight_scale_inv, weight_block_size
|
||||
)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
|
||||
Reference in New Issue
Block a user