Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)

This commit is contained in:
Baizhou Zhang
2025-02-24 04:07:25 -08:00
committed by GitHub
parent 27a46317b6
commit b110084654
4 changed files with 565 additions and 19 deletions

View File

@@ -14,6 +14,7 @@
import json
import logging
import math
from enum import IntEnum, auto
from typing import List, Optional, Set, Union
@@ -103,7 +104,20 @@ class ModelConfig:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_config.kv_lora_rank
self.qk_nope_head_dim = self.hf_config.qk_nope_head_dim
self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
self.v_head_dim = self.hf_config.v_head_dim
# Handle rope scaling with yarn
self.scaling = 1 / math.sqrt(self.qk_nope_head_dim + self.qk_rope_head_dim)
if self.hf_config.rope_scaling:
mscale_all_dim = self.hf_config.rope_scaling.get(
"mscale_all_dim", False
)
scaling_factor = self.hf_config.rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
self.head_dim = 128
self.attention_arch = AttentionArch.MLA
@@ -414,3 +428,9 @@ def is_multimodal_model(model_architectures: List[str]):
def is_encoder_decoder_model(model_architectures: List[str]):
return "MllamaForConditionalGeneration" in model_architectures
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * mscale * math.log(scale) + 1.0