Fix MTP with Deepseek R1 Fp4 (#7376)
This commit is contained in:
@@ -330,6 +330,12 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.tp_rank = get_tensor_model_parallel_rank()
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
self.num_experts = num_experts
|
self.num_experts = num_experts
|
||||||
self.expert_map = None
|
self.expert_map = None
|
||||||
|
|
||||||
|
if enable_flashinfer_moe and quant_config is None:
|
||||||
|
logger.warning("Disable flashinfer MoE when quantization config is None.")
|
||||||
|
enable_flashinfer_moe = False
|
||||||
|
enable_ep_moe = False
|
||||||
|
|
||||||
self.enable_flashinfer_moe = enable_flashinfer_moe
|
self.enable_flashinfer_moe = enable_flashinfer_moe
|
||||||
if enable_ep_moe:
|
if enable_ep_moe:
|
||||||
assert (
|
assert (
|
||||||
|
|||||||
@@ -44,6 +44,12 @@ class DeepseekModelNextN(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
|
||||||
|
logger.warning(
|
||||||
|
"Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
|
||||||
|
)
|
||||||
|
quant_config = None
|
||||||
|
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
|||||||
@@ -2201,7 +2201,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
||||||
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
||||||
cat_dim = 0
|
cat_dim = 0
|
||||||
if (
|
if self.quant_config is not None and (
|
||||||
self.quant_config.get_name() == "awq"
|
self.quant_config.get_name() == "awq"
|
||||||
or self.quant_config.get_name() == "moe_wna16"
|
or self.quant_config.get_name() == "moe_wna16"
|
||||||
):
|
):
|
||||||
@@ -2232,6 +2232,13 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
for scale in ["k_scale", "v_scale"]:
|
for scale in ["k_scale", "v_scale"]:
|
||||||
if scale in name:
|
if scale in name:
|
||||||
name = name.replace(f"{scale[0]}_proj", "attn_mqa")
|
name = name.replace(f"{scale[0]}_proj", "attn_mqa")
|
||||||
|
break
|
||||||
|
if name not in params_dict:
|
||||||
|
# The modelopt checkpoint contains weights that the MTP module does not need:
|
||||||
|
# model.decoder.self_attn.attn_mqa.v_scale and
|
||||||
|
# model.decoder.self_attn.attn_mqa.k_scale
|
||||||
|
logger.warning(f"{name} not found in params_dict.")
|
||||||
|
continue
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
|
|||||||
Reference in New Issue
Block a user