Enable native ModelOpt quantization support (1/3) (#7149)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
Zhiyu
2025-10-06 13:24:15 -07:00
committed by GitHub
parent eb30b888db
commit 155cbb51f0
11 changed files with 464 additions and 42 deletions

View File

@@ -17,7 +17,7 @@ import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import List, Optional, Set, Union
from typing import Dict, List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
@@ -85,6 +85,7 @@ class ModelConfig:
enable_multimodal: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
modelopt_quant: Optional[Union[str, Dict]] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
@@ -94,6 +95,7 @@ class ModelConfig:
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.modelopt_quant = modelopt_quant
self.is_draft_model = is_draft_model
self.model_impl = model_impl
@@ -209,6 +211,7 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
modelopt_quant=server_args.modelopt_quant,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
**kwargs,
@@ -477,54 +480,52 @@ class ModelConfig:
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
is_local = os.path.exists(self.model_path)
modelopt_quant_config = {"quant_method": "modelopt"}
if not is_local:
import huggingface_hub
try:
from huggingface_hub import HfApi
from huggingface_hub import HfApi, hf_hub_download
hf_api = HfApi()
def check_hf_quant_config():
return hf_api.file_exists(
self.model_path, "hf_quant_config.json"
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
# Download and parse the quantization config for remote models
quant_config_file = hf_hub_download(
repo_id=self.model_path,
filename="hf_quant_config.json",
revision=self.revision,
)
# Retry HF API call up to 3 times
file_exists = retry(
check_hf_quant_config,
max_retry=2,
initial_delay=1.0,
max_delay=5.0,
)
if file_exists:
quant_cfg = modelopt_quant_config
with open(quant_config_file) as f:
quant_config_dict = json.load(f)
quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
except huggingface_hub.errors.OfflineModeIsEnabled:
logger.warning(
"Offline mode is enabled, skipping hf_quant_config.json check"
)
except Exception as e:
logger.warning(
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
)
pass
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
quant_config_file = os.path.join(
self.model_path, "hf_quant_config.json"
)
with open(quant_config_file) as f:
quant_config_dict = json.load(f)
json_quant_configs = quant_config_dict["quantization"]
quant_algo = json_quant_configs.get("quant_algo", None)
if quant_algo == "MIXED_PRECISION":
quant_cfg = {"quant_method": "w4afp8"}
else:
quant_cfg = modelopt_quant_config
quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
return quant_cfg
def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
"""Parse ModelOpt quantization config and return the appropriate quant_method."""
json_quant_configs = quant_config_dict["quantization"]
quant_algo = json_quant_configs.get("quant_algo", None)
if quant_algo == "MIXED_PRECISION":
return {"quant_method": "w4afp8"}
elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
return {"quant_method": "modelopt_fp4"}
elif quant_algo and "FP8" in quant_algo:
return {"quant_method": "modelopt_fp8"}
else:
# Default to FP8 for backward compatibility
return {"quant_method": "modelopt_fp8"}
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
@@ -543,7 +544,8 @@ class ModelConfig:
optimized_quantization_methods = [
"fp8",
"marlin",
"modelopt",
"modelopt_fp8",
"modelopt_fp4",
"gptq_marlin_24",
"gptq_marlin",
"awq_marlin",

View File

@@ -0,0 +1,11 @@
"""
ModelOpt related constants
"""
QUANT_CFG_CHOICES = {
"fp8": "FP8_DEFAULT_CFG",
"int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
"w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
"nvfp4": "NVFP4_DEFAULT_CFG",
"nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
}

View File

@@ -72,7 +72,7 @@ if TYPE_CHECKING:
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config,
"blockwise_int8": BlockInt8Config,
"modelopt": ModelOptFp8Config,
"modelopt_fp8": ModelOptFp8Config,
"modelopt_fp4": ModelOptFp4Config,
"w8a8_int8": W8A8Int8Config,
"w8a8_fp8": W8A8Fp8Config,

View File

@@ -113,7 +113,7 @@ class ModelOptFp8Config(QuantizationConfig):
@classmethod
def get_name(cls) -> str:
return "modelopt"
return "modelopt_fp8"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:

View File

@@ -880,7 +880,7 @@ class ModelRunner:
load_config = LoadConfig(load_format=load_format)
# Only support DefaultModelLoader for now
loader = get_model_loader(load_config)
loader = get_model_loader(load_config, self.model_config)
if not isinstance(loader, DefaultModelLoader):
message = f"Failed to get model loader: {loader}."
return False, message

View File

@@ -24,7 +24,7 @@ def get_model(
load_config: LoadConfig,
device_config: DeviceConfig,
) -> nn.Module:
loader = get_model_loader(load_config)
loader = get_model_loader(load_config, model_config)
return loader.load_model(
model_config=model_config,
device_config=device_config,

View File

@@ -37,10 +37,22 @@ import numpy as np
import requests
import safetensors.torch
import torch
# Try to import accelerate (optional dependency)
try:
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
HAS_ACCELERATE = True
except ImportError:
HAS_ACCELERATE = False
infer_auto_device_map = None
init_empty_weights = None
get_max_memory = None
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
trigger_transferring_weights_request,
)
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
post_load_weights,
set_default_torch_dtype,
)
# Constants for memory management
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
)
from sglang.srt.model_loader.weight_utils import (
_BAR_FORMAT,
default_weight_loader,
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
from sglang.srt.layers.quantization.base_config import QuantizationConfig
_is_npu = is_npu()
# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
# which contains the complete mapping of quantization config choices
@contextmanager
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
model_config.model_path, model_config.revision, fall_back_to_pt=True
)
def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
    """Load and prepare the base model for ModelOpt quantization.

    This method handles the common model loading logic shared between
    DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).

    Raises:
        ImportError: if the optional ``accelerate`` dependency is missing.
        TypeError: if ``model_config.modelopt_quant`` is not a string preset key.
    """
    if not HAS_ACCELERATE:
        raise ImportError(
            "accelerate is required for ModelOpt quantization. "
            "Please install it with: pip install accelerate"
        )

    # Validate the quant choice *before* the expensive model download and
    # instantiation below. (Previously this check ran only after the full
    # model had been loaded, wasting time/memory on a misconfiguration.)
    quant_choice_str = model_config.modelopt_quant
    if not isinstance(quant_choice_str, str):
        raise TypeError(
            f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
            f"got {type(quant_choice_str)}"
        )
    logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")

    hf_config = AutoConfig.from_pretrained(
        model_config.model_path, trust_remote_code=True
    )

    # Build a weight-less (meta) model first so we can infer a device map
    # without allocating real parameter memory.
    with init_empty_weights():
        torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
        model = AutoModelForCausalLM.from_config(
            hf_config, torch_dtype=torch_dtype, trust_remote_code=True
        )

    max_memory = get_max_memory()
    inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)

    on_cpu = "cpu" in inferred_device_map.values()
    model_kwargs = {"torch_dtype": "auto"}
    device_map = "auto"

    if on_cpu:
        # The model spills onto CPU: shrink each GPU's budget so ModelOpt
        # calibration has headroom (int keys are GPU device indices).
        for device in max_memory.keys():
            if isinstance(device, int):
                max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION

        logger.warning(
            "Model does not fit to the GPU mem. "
            f"We apply the following memory limit for calibration: \n{max_memory}\n"
            f"If you hit GPU OOM issue, please adjust the memory fraction "
            f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
            "reduce the calibration `batch_size` manually."
        )
        model_kwargs["max_memory"] = max_memory

    # Real load with actual weights, placed per the auto device map.
    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_path,
        device_map=device_map,
        **model_kwargs,
        trust_remote_code=True,
    )

    return model
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
) -> nn.Module:
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
# Load base model using shared method
model = self._load_modelopt_base_model(model_config)
# Note: DefaultModelLoader doesn't do additional quantization processing
# For full ModelOpt quantization, use ModelOptModelLoader
return model.eval()
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config,
)
self.load_weights_and_postprocess(
model, self._get_all_weights(model_config, model), target_device
)
self.load_weights_and_postprocess(
model, self._get_all_weights(model_config, model), target_device
)
return model.eval()
@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
return model.eval()
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
class ModelOptModelLoader(DefaultModelLoader):
    """
    Model loader that applies NVIDIA Model Optimizer quantization
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Any ModelOpt specific initialization if needed

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        # Loads the unquantized base model, then quantizes it in place with
        # ModelOpt using the preset named by `model_config.modelopt_quant`.
        logger.info("ModelOptModelLoader: Loading base model...")

        # Use shared method from parent class to load base model
        model = self._load_modelopt_base_model(model_config)

        # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
        try:
            import modelopt.torch.quantization as mtq
            from modelopt.torch.utils.dataset_utils import create_forward_loop
        except ImportError:
            logger.error(
                "NVIDIA Model Optimizer (modelopt) library not found. "
                "Please install it to use 'modelopt_quant' feature."
            )
            raise

        # Resolve the user preset (e.g. "fp8") to the name of a config
        # attribute in modelopt.torch.quantization via QUANT_CFG_CHOICES.
        quant_choice_str = model_config.modelopt_quant

        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)

        if not quant_cfg_name:
            raise ValueError(
                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
                "attribute names of config objects in modelopt.torch.quantization."
            )

        try:
            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
            quant_cfg = getattr(mtq, quant_cfg_name)
        except AttributeError:
            raise AttributeError(
                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
            )

        # For now, assume no calibration. Calibration setup is a separate, more complex step.
        # NOTE(review): with use_calibration hard-coded False, the warning
        # below is currently unreachable — confirm intended once calibration
        # becomes configurable.
        use_calibration = False  # This would ideally be a configurable parameter
        calib_dataloader = None  # This would need to be provided/configured

        calibrate_loop = (
            create_forward_loop(dataloader=calib_dataloader)
            if use_calibration
            else None
        )

        if use_calibration and calib_dataloader is None:
            logger.warning(
                "ModelOpt calibration requested but no calib_dataloader provided. "
                "Proceeding without calibration. Quantization accuracy may be affected."
            )

        logger.info(
            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
        )

        try:
            # mtq.quantize mutates/wraps the model according to quant_cfg;
            # forward_loop is None when no calibration is performed.
            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
            logger.info("Model successfully quantized with ModelOpt.")
        except Exception as e:
            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
            raise

        mtq.print_quant_summary(model)

        return model.eval()
def get_model_loader(
load_config: LoadConfig, model_config: Optional[ModelConfig] = None
) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if (
model_config
and hasattr(model_config, "modelopt_quant")
and model_config.modelopt_quant
):
logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
return ModelOptModelLoader(load_config)
if isinstance(load_config.load_format, type):
return load_config.load_format(load_config)

View File

@@ -226,6 +226,9 @@ def get_quant_config(
return ModelOptFp4Config.from_config(config)
else:
return quant_cls.from_config(config)
elif model_config.quantization == "modelopt_fp8":
if config["producer"]["name"] == "modelopt_fp8":
return quant_cls.from_config(config)
else:
raise ValueError(
f"Unsupported quantization config"

View File

@@ -20,7 +20,7 @@ import logging
import os
import random
import tempfile
from typing import List, Literal, Optional, Union
from typing import Dict, List, Literal, Optional, Union
from sglang.srt.connector import ConnectorType
from sglang.srt.function_call.function_call_parser import FunctionCallParser
@@ -162,6 +162,7 @@ class ServerArgs:
load_format: str = "auto"
model_loader_extra_config: str = "{}"
trust_remote_code: bool = False
modelopt_quant: Optional[Union[str, Dict]] = None
context_length: Optional[int] = None
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
@@ -1455,6 +1456,14 @@ class ServerArgs:
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
"default to 1.0, which may cause accuracy issues. ",
)
parser.add_argument(
"--modelopt-quant",
type=str,
default=ServerArgs.modelopt_quant,
help="The ModelOpt quantization configuration. "
"Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
"This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,