diff --git a/python/sglang/srt/model_loader/weight_utils.py b/python/sglang/srt/model_loader/weight_utils.py
index d1c44e4f7..e61a521e1 100644
--- a/python/sglang/srt/model_loader/weight_utils.py
+++ b/python/sglang/srt/model_loader/weight_utils.py
@@ -34,6 +34,7 @@ from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
+from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp4Config
 from sglang.srt.utils import print_warning_once
 
 logger = logging.getLogger(__name__)
@@ -206,7 +207,10 @@ def get_quant_config(
         config["adapter_name_or_path"] = model_name_or_path
     elif model_config.quantization == "modelopt":
         if config["producer"]["name"] == "modelopt":
-            return quant_cls.from_config(config)
+            if "FP4" in config["quantization"]["quant_algo"]:
+                return ModelOptFp4Config.from_config(config)
+            else:
+                return quant_cls.from_config(config)
         else:
             raise ValueError(
                 f"Unsupported quantization config"
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 041ecc18c..192ab50ed 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1926,6 +1926,8 @@ class DeepseekV2ForCausalLM(nn.Module):
         if (
             deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
             and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+            and hasattr(self.quant_config, "weight_block_size")
+            and self.quant_config.weight_block_size is not None
         ):
             self._weight_requant_ue8m0()
 
@@ -2158,12 +2160,9 @@ class DeepseekV2ForCausalLM(nn.Module):
                     "k_scale" in name or "v_scale" in name
                 ) and name not in params_dict:
                     # modelopt attn kv scale is named differently
-                    if any(scale in name for scale in ["k_scale", "v_scale"]):
-                        name = name.replace("_proj", "attn_mqa")
-                    else:
-                        logger.warning(
-                            f"Unknown scale found in checkpoint: {name}"
-                        )
+                    for scale in ["k_scale", "v_scale"]:
+                        if scale in name:
+                            name = name.replace(f"{scale[0]}_proj", "attn_mqa")
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader