[Fix] Fix the issue encountered when running inference with LongCat-Flash/MTP EP MoE on B200 (#9916)
@@ -651,9 +651,6 @@ class LongcatFlashForCausalLM(nn.Module):
                     ).T
                 else:
                     w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                # Fix deepseek v3 blockwise bmm by using deep_gemm
                 use_deep_gemm_bmm = False

                 if w.dtype in (
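The three comment lines removed above were inherited from the DeepSeek V3 implementation; the logic they annotated (`use_deep_gemm_bmm` and the FP8 dtype check) is untouched. For context on what the old NOTE meant: `bmm_fp8` accepts only a single per-tensor scale, so a block-quantized `kv_b_proj` has to be requantized with one global scale, roughly as in this illustrative sketch (not the actual sglang helper):

```python
# Illustrative only: the idea behind the removed NOTE. Collapsing block-wise
# FP8 scales to one per-tensor scale loses per-block resolution, which is the
# "may affect the accuracy" caveat in the old comment.
import torch

def requant_per_tensor(w: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
    # One global scale chosen from the tensor-wide max magnitude.
    scale = w.abs().amax().float().clamp(min=1e-12) / finfo.max
    w_fp8 = (w / scale).clamp(finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_fp8, scale

w_fp8, scale = requant_per_tensor(torch.randn(64, 64, dtype=torch.bfloat16))
```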
@@ -790,6 +787,9 @@ class LongcatFlashForCausalLM(nn.Module):
             self.config.hidden_size / self.config.kv_lora_rank
         ) ** 0.5

+        # TODO(linguoyuan) EPMoE not support DEEPGEMM_BLACKWELL, DeepEP needs to be supported in the future
+        deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False
+
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
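This hunk is the heart of the fix: the EP MoE kernels do not yet support the DEEPGEMM_BLACKWELL UE8M0 path on B200, so `DEEPGEMM_SCALE_UE8M0` is forced off before the pre-existing guard that triggers weight requantization. A minimal, runnable sketch of the resulting control flow (the namespace below is a stand-in; in sglang this is the real `deep_gemm_wrapper` module):

```python
from types import SimpleNamespace

# Stand-in for sglang's deep_gemm_wrapper module (illustrative only).
deep_gemm_wrapper = SimpleNamespace(
    ENABLE_JIT_DEEPGEMM=True,   # JIT DeepGEMM is available
    DEEPGEMM_SCALE_UE8M0=True,  # would normally enable the requant pass
)

# The fix: EP MoE on Blackwell cannot use the UE8M0 path yet, so pin it off.
deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0 = False

# The pre-existing guard now always evaluates to False on B200, so the
# UE8M0 weight requantization pass is skipped instead of mis-firing.
if (
    deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
    and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
):
    print("would call self._weight_requant_ue8m0()")  # unreachable after the fix
```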
@@ -804,24 +804,35 @@ class LongcatFlashForCausalLM(nn.Module):
         for layer_id in range(self.config.num_hidden_layers):
             layer = self.model.layers[layer_id]
             for i in range(2):
-                for module in [
-                    layer.self_attn[i].fused_qkv_a_proj_with_mqa,
-                    layer.self_attn[i].q_b_proj,
-                    layer.self_attn[i].kv_b_proj,
-                    layer.self_attn[i].o_proj,
-                ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                self_attn = layer.self_attn[i]
+                module_list = [
+                    self_attn.kv_b_proj,
+                    self_attn.o_proj,
+                ]
+
+                if self.config.q_lora_rank is not None:
+                    module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_b_proj)
+                else:
+                    module_list.append(self_attn.kv_a_proj_with_mqa)
+                    module_list.append(self_attn.q_proj)
+
+                for module in module_list:
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
                 mlp = layer.mlps[i]
                 assert isinstance(mlp, LongcatFlashMLP)
                 for module in [
                     mlp.gate_up_proj,
                     mlp.down_proj,
                 ]:
-                    requant_weight_ue8m0_inplace(
-                        module.weight, module.weight_scale_inv, weight_block_size
-                    )
+                    if hasattr(module, "weight_scale_inv"):
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )

         for layer_id in range(self.config.num_hidden_layers):
             experts = layer.mlp.experts
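Two behavioral changes land in this hunk. First, the requant target list now depends on `self.config.q_lora_rank`: with a low-rank Q, the fused `fused_qkv_a_proj_with_mqa` and `q_b_proj` modules exist, otherwise `kv_a_proj_with_mqa` and `q_proj` are used instead. Second, every target is guarded by `hasattr(module, "weight_scale_inv")`, so modules loaded without FP8 scales are skipped. A self-contained sketch of the selection logic (attribute names follow the diff; the stand-in object is illustrative):

```python
from types import SimpleNamespace

def collect_requant_targets(self_attn, q_lora_rank):
    # kv_b_proj and o_proj are present in both configurations.
    modules = [self_attn.kv_b_proj, self_attn.o_proj]
    if q_lora_rank is not None:
        # Low-rank Q: the Q "A" projection is fused into the qkv_a proj,
        # and q_b_proj holds the "B" half of the decomposition.
        modules += [self_attn.fused_qkv_a_proj_with_mqa, self_attn.q_b_proj]
    else:
        # Full-rank Q: plain q_proj, with kv_a still fused with MQA.
        modules += [self_attn.kv_a_proj_with_mqa, self_attn.q_proj]
    return modules

# Tiny demonstration with a stand-in attention object:
attn = SimpleNamespace(
    kv_b_proj="kv_b", o_proj="o",
    fused_qkv_a_proj_with_mqa="qkv_a", q_b_proj="q_b",
    kv_a_proj_with_mqa="kv_a", q_proj="q",
)
print(collect_requant_targets(attn, q_lora_rank=None))  # ['kv_b', 'o', 'kv_a', 'q']
```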
@@ -344,9 +344,6 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
-            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-            # This may affect the accuracy of fp8 model.
-            # Fix deepseek v3 blockwise bmm by using deep_gemm
             use_deep_gemm_bmm = False
             if w.dtype in (
                 torch.float8_e4m3fn,
@@ -480,24 +477,35 @@ class LongcatFlashForCausalLMNextN(LongcatFlashForCausalLM):
     def _weight_requant_ue8m0(self):
         weight_block_size = self.quant_config.weight_block_size
         layer = self.model.decoder
-        for module in [
-            layer.self_attn.fused_qkv_a_proj_with_mqa,
-            layer.self_attn.q_b_proj,
-            layer.self_attn.kv_b_proj,
-            layer.self_attn.o_proj,
-        ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+        self_attn = layer.self_attn
+        module_list = [
+            self_attn.kv_b_proj,
+            self_attn.o_proj,
+        ]
+
+        if self.config.q_lora_rank is not None:
+            module_list.append(self_attn.fused_qkv_a_proj_with_mqa)
+            module_list.append(self_attn.q_b_proj)
+        else:
+            module_list.append(self_attn.kv_a_proj_with_mqa)
+            module_list.append(self_attn.q_proj)
+
+        for module in module_list:
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
         mlp = layer.mlps
         assert isinstance(mlp, LongcatFlashMLP)
         for module in [
             mlp.gate_up_proj,
             mlp.down_proj,
         ]:
-            requant_weight_ue8m0_inplace(
-                module.weight, module.weight_scale_inv, weight_block_size
-            )
+            if hasattr(module, "weight_scale_inv"):
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
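The NextN (MTP draft) model receives the identical refactor in its `_weight_requant_ue8m0`, just on the single `self.model.decoder` layer rather than a loop over `self.model.layers`. The failure mode the `hasattr` guard removes is easy to demonstrate with a hypothetical unquantized module (stand-in class, not sglang code):

```python
import torch

class _Bf16Proj:
    """Hypothetical stand-in for a projection loaded without FP8 scales."""
    weight = torch.zeros(4, 4, dtype=torch.bfloat16)

proj = _Bf16Proj()
# The old loop read proj.weight_scale_inv unconditionally, raising
# AttributeError for BF16 weights; the new guard makes this a clean no-op.
assert not hasattr(proj, "weight_scale_inv")
```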