Fix DeepSeek R1 0528 FP4 tensor name mismatch during weight loading. (#7164)

@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)

@@ -206,7 +207,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
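
In effect, get_quant_config now routes ModelOpt checkpoints by the quant_algo string in their quantization metadata and falls back to the generic quant class otherwise. A minimal sketch of that dispatch, using a hypothetical ModelOpt payload (the field values here are illustrative, not the exact checkpoint contents):

# Hypothetical ModelOpt quantization payload, for illustration only;
# the real dict is parsed from the checkpoint's quantization config.
config = {
    "producer": {"name": "modelopt"},
    "quantization": {"quant_algo": "NVFP4"},
}

if config["producer"]["name"] == "modelopt":
    if "FP4" in config["quantization"]["quant_algo"]:
        chosen = "ModelOptFp4Config"  # FP4 checkpoints need the FP4 config class
    else:
        chosen = "quant_cls"  # default class resolved by get_quantization_config
print(chosen)  # -> ModelOptFp4Config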

@@ -1926,6 +1926,8 @@ class DeepseekV2ForCausalLM(nn.Module):
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
             self._weight_requant_ue8m0()
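
The two added conditions make the UE8M0 requantization path defensive: an FP4 quant config does not necessarily expose weight_block_size, so dereferencing it unconditionally would raise AttributeError during loading. A toy illustration of the guard (the classes below are stand-ins, not the real sglang config classes):

class BlockFp8Config:  # stand-in for a block-quantized FP8 config
    weight_block_size = [128, 128]

class Fp4Config:  # stand-in for an FP4 config that has no block size
    pass

for cfg in (BlockFp8Config(), Fp4Config()):
    should_requant = (
        hasattr(cfg, "weight_block_size")
        and cfg.weight_block_size is not None
    )
    print(type(cfg).__name__, should_requant)  # True, then False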

@@ -2158,12 +2160,9 @@ class DeepseekV2ForCausalLM(nn.Module):
                 "k_scale" in name or "v_scale" in name
             ) and name not in params_dict:
                 # modelopt attn kv scale is named differently
-                if any(scale in name for scale in ["k_scale", "v_scale"]):
-                    name = name.replace("_proj", "attn_mqa")
-                else:
-                    logger.warning(
-                        f"Unknown scale found in checkpoint: {name}"
-                    )
+                for scale in ["k_scale", "v_scale"]:
+                    if scale in name:
+                        name = name.replace(f"{scale[0]}_proj", "attn_mqa")
                 param = params_dict[name]
                 weight_loader = getattr(
                     param, "weight_loader", default_weight_loader