Enable native ModelOpt quantization support (1/3) (#7149)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
Author: Zhiyu
Date: 2025-10-06 13:24:15 -07:00
Committed by: GitHub
Parent: eb30b888db
Commit: 155cbb51f0
11 changed files with 464 additions and 42 deletions

@@ -17,7 +17,7 @@ import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -85,6 +85,7 @@ class ModelConfig:
         enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
+        modelopt_quant: Optional[Union[str, Dict]] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
@@ -94,6 +95,7 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+        self.modelopt_quant = modelopt_quant
         self.is_draft_model = is_draft_model
         self.model_impl = model_impl
@@ -209,6 +211,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            modelopt_quant=server_args.modelopt_quant,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             model_impl=server_args.model_impl,
             **kwargs,
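For orientation, here is a minimal sketch of how the new field could be exercised once threaded through. The import path and the `"fp8"` value string are assumptions inferred from this diff (the accepted values follow from the `quant_method` names introduced below, and this is only part 1/3), not a confirmed public API:

```python
# Hypothetical usage sketch; the import path and the "fp8" value are
# assumptions inferred from this diff, not a confirmed public API.
from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    model_path="nvidia/Llama-3.1-8B-Instruct-FP8",
    modelopt_quant="fp8",  # Union[str, Dict]: a dict form may carry finer-grained settings
)
print(config.modelopt_quant)  # -> "fp8"
```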
@@ -477,54 +480,52 @@ class ModelConfig:
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
         # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
         is_local = os.path.exists(self.model_path)
-        modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             import huggingface_hub
 
             try:
-                from huggingface_hub import HfApi
+                from huggingface_hub import HfApi, hf_hub_download
 
                 hf_api = HfApi()
-
-                def check_hf_quant_config():
-                    return hf_api.file_exists(
-                        self.model_path, "hf_quant_config.json"
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    # Download and parse the quantization config for remote models
+                    quant_config_file = hf_hub_download(
+                        repo_id=self.model_path,
+                        filename="hf_quant_config.json",
+                        revision=self.revision,
                     )
-
-                # Retry HF API call up to 3 times
-                file_exists = retry(
-                    check_hf_quant_config,
-                    max_retry=2,
-                    initial_delay=1.0,
-                    max_delay=5.0,
-                )
-
-                if file_exists:
-                    quant_cfg = modelopt_quant_config
+                    with open(quant_config_file) as f:
+                        quant_config_dict = json.load(f)
+                    quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
             except huggingface_hub.errors.OfflineModeIsEnabled:
                 logger.warning(
                     "Offline mode is enabled, skipping hf_quant_config.json check"
                 )
             except Exception as e:
                 logger.warning(
                     f"Failed to check hf_quant_config.json: {self.model_path} {e}"
                 )
-                pass
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
             quant_config_file = os.path.join(
                 self.model_path, "hf_quant_config.json"
             )
             with open(quant_config_file) as f:
                 quant_config_dict = json.load(f)
-            json_quant_configs = quant_config_dict["quantization"]
-            quant_algo = json_quant_configs.get("quant_algo", None)
-            if quant_algo == "MIXED_PRECISION":
-                quant_cfg = {"quant_method": "w4afp8"}
-            else:
-                quant_cfg = modelopt_quant_config
+            quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
 
         return quant_cfg
 
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+        """Parse ModelOpt quantization config and return the appropriate quant_method."""
+        json_quant_configs = quant_config_dict["quantization"]
+        quant_algo = json_quant_configs.get("quant_algo", None)
+
+        if quant_algo == "MIXED_PRECISION":
+            return {"quant_method": "w4afp8"}
+        elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
+            return {"quant_method": "modelopt_fp4"}
+        elif quant_algo and "FP8" in quant_algo:
+            return {"quant_method": "modelopt_fp8"}
+        else:
+            # Default to FP8 for backward compatibility
+            return {"quant_method": "modelopt_fp8"}
+
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
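To make the new mapping concrete, here is a standalone replica of the `_parse_modelopt_quant_config` branching run against illustrative `hf_quant_config.json` payloads (the sample contents are made up, not taken from real checkpoints). Note that the `"NVFP4" in quant_algo` test is subsumed by the `"FP4"` substring check, so it is effectively redundant:

```python
import json

def parse_quant_method(quant_config_dict: dict) -> dict:
    # Mirrors the branching of _parse_modelopt_quant_config above.
    quant_algo = quant_config_dict["quantization"].get("quant_algo", None)
    if quant_algo == "MIXED_PRECISION":
        return {"quant_method": "w4afp8"}
    elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
        return {"quant_method": "modelopt_fp4"}
    elif quant_algo and "FP8" in quant_algo:
        return {"quant_method": "modelopt_fp8"}
    else:
        # Default to FP8 for backward compatibility
        return {"quant_method": "modelopt_fp8"}

# Illustrative payloads (made up, not from real checkpoints):
assert parse_quant_method(json.loads('{"quantization": {"quant_algo": "NVFP4"}}')) == {"quant_method": "modelopt_fp4"}
assert parse_quant_method(json.loads('{"quantization": {"quant_algo": "FP8"}}')) == {"quant_method": "modelopt_fp8"}
assert parse_quant_method({"quantization": {}}) == {"quant_method": "modelopt_fp8"}
```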
@@ -543,7 +544,8 @@ class ModelConfig:
         optimized_quantization_methods = [
             "fp8",
             "marlin",
-            "modelopt",
+            "modelopt_fp8",
+            "modelopt_fp4",
             "gptq_marlin_24",
             "gptq_marlin",
             "awq_marlin",