Refactor FlashInfer logic for DeepSeek V3 and fix accuracy bug (#3785)
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
@@ -103,7 +104,20 @@ class ModelConfig:
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_config.kv_lora_rank
|
||||
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
|
||||
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_config.v_head_dim
|
||||
|
||||
# Handle rope scaling with yarn
|
||||
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
|
||||
if self.hf_config.rope_scaling:
|
||||
mscale_all_dim = self.hf_config.rope_scaling.get(
|
||||
"mscale_all_dim", False
|
||||
)
|
||||
scaling_factor = self.hf_config.rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
|
||||
elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
|
||||
self.head_dim = 128
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
@@ -414,3 +428,9 @@ def is_multimodal_model(model_architectures: List[str]):
|
||||
|
||||
def is_encoder_decoder_model(model_architectures: List[str]):
    """Return True if any listed architecture is an encoder-decoder model.

    Currently the only encoder-decoder architecture recognized is Mllama
    (Llama 3.2 Vision); extend the set below if more are added.
    """
    encoder_decoder_archs = {"MllamaForConditionalGeneration"}
    return any(arch in encoder_decoder_archs for arch in model_architectures)
|
||||
|
||||
|
||||
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    """Compute the YaRN attention magnitude-scaling factor.

    For context-extension factors above 1, attention logits are scaled by
    ``0.1 * mscale * ln(scale) + 1`` (the YaRN temperature correction);
    no correction is applied when the context is not extended.

    Args:
        scale: RoPE context-extension factor.
        mscale: Model-specific multiplier (e.g. ``mscale_all_dim`` from the
            rope_scaling config).

    Returns:
        The multiplicative scaling factor (1.0 when ``scale <= 1``).
    """
    if scale > 1:
        return 0.1 * mscale * math.log(scale) + 1.0
    return 1.0
|
||||
|
||||
Reference in New Issue
Block a user