Fix DeepSeek R1 0528 FP4 tensor name mismatch issue during weight loading (#7164)
This commit is contained in:
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
|
|||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
from sglang.srt.distributed import get_tensor_model_parallel_rank
|
||||||
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
|
||||||
|
from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
|
||||||
from sglang.srt.utils import print_warning_once
|
from sglang.srt.utils import print_warning_once
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -206,7 +207,10 @@ def get_quant_config(
|
|||||||
config["adapter_name_or_path"] = model_name_or_path
|
config["adapter_name_or_path"] = model_name_or_path
|
||||||
elif model_config.quantization == "modelopt":
|
elif model_config.quantization == "modelopt":
|
||||||
if config["producer"]["name"] == "modelopt":
|
if config["producer"]["name"] == "modelopt":
|
||||||
return quant_cls.from_config(config)
|
if "FP4" in config["quantization"]["quant_algo"]:
|
||||||
|
return ModelOptFp4Config.from_config(config)
|
||||||
|
else:
|
||||||
|
return quant_cls.from_config(config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Unsupported quantization config"
|
f"Unsupported quantization config"
|
||||||
|
|||||||
@@ -1926,6 +1926,8 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
if (
|
if (
|
||||||
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
|
||||||
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
|
||||||
|
and hasattr(self.quant_config, "weight_block_size")
|
||||||
|
and self.quant_config.weight_block_size is not None
|
||||||
):
|
):
|
||||||
self._weight_requant_ue8m0()
|
self._weight_requant_ue8m0()
|
||||||
|
|
||||||
@@ -2158,12 +2160,9 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"k_scale" in name or "v_scale" in name
|
"k_scale" in name or "v_scale" in name
|
||||||
) and name not in params_dict:
|
) and name not in params_dict:
|
||||||
# modelopt attn kv scale is named differently
|
# modelopt attn kv scale is named differently
|
||||||
if any(scale in name for scale in ["k_scale", "v_scale"]):
|
for scale in ["k_scale", "v_scale"]:
|
||||||
name = name.replace("_proj", "attn_mqa")
|
if scale in name:
|
||||||
else:
|
name = name.replace(f"{scale[0]}_proj", "attn_mqa")
|
||||||
logger.warning(
|
|
||||||
f"Unknown scale found in checkpoint: {name}"
|
|
||||||
)
|
|
||||||
param = params_dict[name]
|
param = params_dict[name]
|
||||||
weight_loader = getattr(
|
weight_loader = getattr(
|
||||||
param, "weight_loader", default_weight_loader
|
param, "weight_loader", default_weight_loader
|
||||||
|
|||||||
Reference in New Issue
Block a user