Enable native ModelOpt quantization support (3/3) (#10154)

Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
Zhiyu
2025-10-21 21:44:29 -07:00
committed by GitHub
parent 4b65ed42cc
commit 80b2b3207a
16 changed files with 1528 additions and 39 deletions

View File

@@ -75,12 +75,7 @@ dependencies = [
]
[project.optional-dependencies]
tracing = [
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
modelopt = ["nvidia-modelopt"]
test = [
"accelerate",
"expecttest",
@@ -107,6 +102,12 @@ cu130_all = [
"sglang[decord]",
"sglang[cu130]"
]
tracing = [
"opentelemetry-api",
"opentelemetry-exporter-otlp",
"opentelemetry-exporter-otlp-proto-grpc",
"opentelemetry-sdk",
]
# To be deprecated in 2 weeks
blackwell = ["sglang[dev]"]

View File

@@ -6,6 +6,7 @@ from typing import List, Optional, Union
import orjson
from sglang.srt.configs.modelopt_config import ModelOptConfig
from sglang.srt.utils import is_hip
logger = logging.getLogger(__name__)
@@ -51,6 +52,11 @@ class LoadConfig:
decryption_key_file: If set, decrypts the output files with a password read
from this file (after PBKDF2).
decrypt_max_concurrency: The maximum number of concurrent processes to decrypt the safetensor files. -1 means no limit.
# ModelOpt-specific loading options
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
"""
load_format: Union[str, LoadFormat] = LoadFormat.AUTO
@@ -64,6 +70,14 @@ class LoadConfig:
remote_instance_weight_loader_seed_instance_service_port: Optional[int] = None
remote_instance_weight_loader_send_weights_group_ports: Optional[List[int]] = None
# ModelOpt-specific loading options
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
# ModelOpt configuration object
modelopt_config: Optional[ModelOptConfig] = None
def __post_init__(self):
model_loader_extra_config = self.model_loader_extra_config or {}
if isinstance(model_loader_extra_config, str):
@@ -78,6 +92,14 @@ class LoadConfig:
else:
self.ignore_patterns = ["original/**/*"]
# Create ModelOptConfig if not provided
if self.modelopt_config is None:
self.modelopt_config = ModelOptConfig(
checkpoint_restore_path=self.modelopt_checkpoint_restore_path,
checkpoint_save_path=self.modelopt_checkpoint_save_path,
export_path=self.modelopt_export_path,
)
def _verify_load_format(self) -> None:
if not isinstance(self.load_format, str):
return

View File

@@ -17,7 +17,7 @@ import logging
import math
import os
from enum import Enum, IntEnum, auto
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, List, Optional, Set, Union
import torch
from transformers import PretrainedConfig
@@ -89,7 +89,6 @@ class ModelConfig:
enable_multimodal: Optional[bool] = None,
dtype: str = "auto",
quantization: Optional[str] = None,
modelopt_quant: Optional[Union[str, Dict]] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[
@@ -97,15 +96,19 @@ class ModelConfig:
] = None, # TODO: remove this, it is not a model config
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
sampling_defaults: str = "openai",
quantize_and_serve: bool = False,
) -> None:
# Parse args
self.model_path = model_path
self.revision = revision
self.quantization = quantization
self.modelopt_quant = modelopt_quant
self.is_draft_model = is_draft_model
self.model_impl = model_impl
self.sampling_defaults = sampling_defaults
self.quantize_and_serve = quantize_and_serve
# Validate quantize_and_serve configuration
self._validate_quantize_and_serve_config()
# Get hf config
self._maybe_pull_model_tokenizer_from_remote()
@@ -219,10 +222,10 @@ class ModelConfig:
enable_multimodal=server_args.enable_multimodal,
dtype=server_args.dtype,
quantization=server_args.quantization,
modelopt_quant=server_args.modelopt_quant,
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
model_impl=server_args.model_impl,
sampling_defaults=server_args.sampling_defaults,
quantize_and_serve=server_args.quantize_and_serve,
**kwargs,
)
@@ -547,6 +550,56 @@ class ModelConfig:
# Default to FP8 for backward compatibility
return {"quant_method": "modelopt_fp8"}
def _is_already_quantized(self) -> bool:
    """Check if the model is already quantized based on config files."""
    # Check for HuggingFace quantization config: a checkpoint shipping an
    # hf_quant_config.json is treated as pre-quantized, so the loader can
    # skip re-quantization and load it directly.
    from sglang.srt.utils import has_hf_quant_config

    return has_hf_quant_config(self.model_path)
def _get_modelopt_quant_type(self) -> str:
"""Extract ModelOpt quantization type from unified quantization flag."""
if self.quantization == "modelopt_fp8":
return "fp8"
elif self.quantization == "modelopt_fp4":
return "nvfp4"
elif self.quantization == "modelopt":
# Auto-detect from model config
quant_cfg = self._parse_quant_hf_config()
if quant_cfg:
quant_method = quant_cfg.get("quant_method", "").lower()
if "fp4" in quant_method:
return "fp4"
elif "fp8" in quant_method:
return "fp8"
# Default to fp8 if can't detect
return "fp8"
else:
return "fp8" # Default fallback
def _validate_quantize_and_serve_config(self):
"""Validate quantize_and_serve configuration."""
if not self.quantize_and_serve:
return
# Check if ModelOpt quantization is specified
modelopt_quantization_specified = self.quantization in [
"modelopt",
"modelopt_fp8",
"modelopt_fp4",
]
if not modelopt_quantization_specified:
raise ValueError("quantize_and_serve requires ModelOpt quantization")
# quantize_and_serve is disabled due to compatibility issues
raise NotImplementedError(
"quantize_and_serve functionality is currently disabled due to compatibility issues. "
"Please use the separate quantize-then-deploy workflow instead. "
"Step 1: Quantize and export model. "
"Step 2: Deploy the exported model."
)
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]

View File

@@ -0,0 +1,30 @@
# Configuration for NVIDIA ModelOpt quantization integration

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelOptConfig:
    """Configuration for NVIDIA ModelOpt quantization operations.

    Carries the knobs for ModelOpt quantization, checkpoint management,
    and HuggingFace-format model export.

    Args:
        quant: Quantization method/type (e.g., "fp8", "fp4")
        checkpoint_restore_path: Path to restore ModelOpt checkpoint from
        checkpoint_save_path: Path to save ModelOpt checkpoint to
        export_path: Path to export quantized model in HuggingFace format
        quantize_and_serve: Whether to quantize and serve in one step
    """

    quant: Optional[str] = None
    checkpoint_restore_path: Optional[str] = None
    checkpoint_save_path: Optional[str] = None
    export_path: Optional[str] = None
    quantize_and_serve: bool = False

    def __post_init__(self):
        """Validate configuration after initialization."""
        # Intentionally a no-op for now; hook for future validation logic.
        pass

View File

@@ -72,6 +72,7 @@ if TYPE_CHECKING:
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"fp8": Fp8Config,
"blockwise_int8": BlockInt8Config,
"modelopt": ModelOptFp8Config, # Auto-detect, defaults to FP8
"modelopt_fp8": ModelOptFp8Config,
"modelopt_fp4": ModelOptFp4Config,
"w8a8_int8": W8A8Int8Config,

View File

@@ -161,6 +161,26 @@ class QuantizationConfig(ABC):
"""
return None
@classmethod
def _modelopt_override_quantization_method(
cls, hf_quant_config, user_quant
) -> Optional[str]:
"""Shared ModelOpt quantization method override logic."""
if hf_quant_config is None:
return None
# Check if this is a ModelOpt config
quant_algo = hf_quant_config.get("quant_algo", "").upper()
# If user specified generic "modelopt", auto-detect the specific method
if user_quant == "modelopt":
if "FP8" in quant_algo:
return "modelopt_fp8"
elif "NVFP4" in quant_algo or "FP4" in quant_algo:
return "modelopt_fp4"
return None
@staticmethod
def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
"""Get a value from the model's quantization config."""

View File

@@ -111,6 +111,11 @@ class ModelOptFp8Config(QuantizationConfig):
"Detected ModelOpt FP8 checkpoint. The format is experimental and subject to change."
)
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
    """Override quantization method based on the model's config.

    Delegates to the shared ModelOpt auto-detection helper so FP8 and FP4
    configs resolve the generic "modelopt" flag identically.
    """
    return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
@classmethod
def get_name(cls) -> str:
    """Return the registry name of this quantization method."""
    return "modelopt_fp8"
@@ -527,6 +532,11 @@ class ModelOptFp4Config(QuantizationConfig):
self.kv_cache_quant_algo = kv_cache_quant_algo
self.exclude_modules = exclude_modules
@classmethod
def override_quantization_method(cls, hf_quant_config, user_quant):
    """Override quantization method based on the model's config.

    Delegates to the shared ModelOpt auto-detection helper so FP8 and FP4
    configs resolve the generic "modelopt" flag identically.
    """
    return cls._modelopt_override_quantization_method(hf_quant_config, user_quant)
@classmethod
def get_name(cls) -> str:
    """Return the registry name of this quantization method."""
    return "modelopt_fp4"
@@ -608,7 +618,16 @@ class ModelOptFp4Config(QuantizationConfig):
else:
kv_cache_quant_algo = "auto"
group_size = ModelOptFp4Config.common_group_size(config)
group_size = config.get("group_size")
# If group_size is not at top level, try to extract from config_groups
if group_size is None:
config_groups = config.get("config_groups", {})
if config_groups:
# Get group_size from the first group's weights config
first_group = next(iter(config_groups.values()), {})
weights_config = first_group.get("weights", {})
group_size = weights_config.get("group_size")
exclude_modules = config.get("ignore", [])
else:
# Fall back to nested format (hf_quant_config.json - legacy format)
@@ -634,15 +653,15 @@ class ModelOptFp4Config(QuantizationConfig):
)
is_checkpoint_nvfp4_serialized = "NVFP4" in quant_method
if not (group_size and kv_cache_quant_algo) or exclude_modules is None:
if group_size is None or exclude_modules is None:
logger.warning(
f"group_size: {group_size},"
f"kv_cache_quant_algo: {kv_cache_quant_algo},"
f"exclude_modules: {exclude_modules}"
)
raise ValueError(
"NVFP4 quantization requires group size and "
"kv_cache_quant_algo specified in the quantization config"
"NVFP4 quantization requires group_size and exclude_modules "
"specified in the quantization config"
)
return cls(
is_checkpoint_nvfp4_serialized,

View File

@@ -828,6 +828,16 @@ class ModelRunner:
set_cuda_arch()
# Prepare the model config
from sglang.srt.configs.modelopt_config import ModelOptConfig
modelopt_config = ModelOptConfig(
quant=self.server_args.modelopt_quant,
checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
export_path=self.server_args.modelopt_export_path,
quantize_and_serve=self.server_args.quantize_and_serve,
)
self.load_config = LoadConfig(
load_format=self.server_args.load_format,
download_dir=self.server_args.download_dir,
@@ -836,6 +846,7 @@ class ModelRunner:
remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
modelopt_config=modelopt_config,
)
if self.device == "cpu":
self.model_config = adjust_config_with_unaligned_cpu_tp(

View File

@@ -538,12 +538,21 @@ class DefaultModelLoader(BaseModelLoader):
**model_kwargs,
trust_remote_code=True,
)
rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
# Handle both legacy modelopt_quant and unified quantization flags
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
# Legacy approach
quant_choice_str = model_config.modelopt_quant
rank0_log(f"ModelOpt quantization requested (legacy): {quant_choice_str}")
else:
# Unified approach - extract quantization type
quant_choice_str = model_config._get_modelopt_quant_type()
rank0_log(
f"ModelOpt quantization requested (unified): {model_config.quantization} -> {quant_choice_str}"
)
quant_choice_str = model_config.modelopt_quant
if not isinstance(quant_choice_str, str):
raise TypeError(
f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
f"Quantization type must be a string (e.g., 'fp8'), "
f"got {type(quant_choice_str)}"
)
@@ -1764,6 +1773,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg,
quantized_ckpt_restore_path: str | None = None,
quantized_ckpt_save_path: str | None = None,
export_path: str | None = None,
) -> None:
"""
Set up ModelOpt quantization for the given model.
@@ -1774,6 +1784,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg: The quantization configuration
quantized_ckpt_restore_path: Path to restore quantized checkpoint from
quantized_ckpt_save_path: Path to save quantized checkpoint to
export_path: Path to export the quantized model in HuggingFace format
Raises:
ImportError: If ModelOpt is not available
@@ -1798,6 +1809,9 @@ class ModelOptModelLoader(DefaultModelLoader):
rank0_log(
f"Restored quantized model from {quantized_ckpt_restore_path}"
)
# Export model if path provided (even when restoring from checkpoint)
self._maybe_export_modelopt(model, export_path)
return
except Exception as e:
logger.warning(
@@ -1844,9 +1858,75 @@ class ModelOptModelLoader(DefaultModelLoader):
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
)
# Export model if path provided
self._maybe_export_modelopt(model, export_path)
except Exception as e:
raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
def _maybe_export_modelopt(self, model, export_path: str | None) -> None:
    """Export *model* to HuggingFace format when *export_path* is set.

    Export failures are logged as warnings instead of raised, so a failed
    export never aborts model loading.
    """
    if not export_path:
        return
    try:
        # The original model path (stashed by load_model) is forwarded so
        # the tokenizer can be exported alongside the weights.
        original_model_path = getattr(self, "_original_model_path", None)
        self._export_modelopt_checkpoint(
            model, export_path, original_model_path
        )
        rank0_log(
            f"Quantized model exported to HuggingFace format at {export_path}"
        )
    except Exception as e:
        rank0_log(
            f"Warning: Failed to export quantized model to {export_path}: {e}"
        )
def _export_modelopt_checkpoint(
    self,
    model,
    export_path: str,
    model_path: str | None = None,
    trust_remote_code: bool = True,
) -> None:
    """
    Export the quantized model to HuggingFace format using ModelOpt export API.

    Args:
        model: The quantized model to export
        export_path: Directory path to export the model to
        model_path: Path to the original model (for tokenizer export);
            when None, tokenizer export is skipped
        trust_remote_code: Whether to trust remote code for tokenizer loading

    Raises:
        ImportError: If ModelOpt export functionality is not available
        Exception: If export fails
    """
    # NOTE: annotation fixed from `model_path: str = None` to
    # `str | None = None` — None is the documented default.
    try:
        from modelopt.torch.export import export_hf_checkpoint
        from transformers import AutoTokenizer
    except ImportError as e:
        raise ImportError(
            "ModelOpt export functionality is not available. "
            "Please ensure you have the latest version of modelopt installed."
        ) from e

    # Create export directory if it doesn't exist
    os.makedirs(export_path, exist_ok=True)

    # Export the quantized model
    export_hf_checkpoint(model, export_dir=export_path)

    # Export the tokenizer if model_path is provided; tokenizer failure is
    # non-fatal because the weights were already exported successfully.
    if model_path:
        try:
            tokenizer = AutoTokenizer.from_pretrained(
                model_path, trust_remote_code=trust_remote_code
            )
            tokenizer.save_pretrained(export_path)
            rank0_log(f"Tokenizer exported to {export_path}")
        except Exception as e:
            rank0_log(f"Warning: Failed to export tokenizer: {e}")
def load_model(
self,
*,
@@ -1856,28 +1936,52 @@ class ModelOptModelLoader(DefaultModelLoader):
logger.info("ModelOptModelLoader: Loading base model...")
# Use shared method from parent class to load base model
# Store the original model path for tokenizer export
self._original_model_path = model_config.model_path
# Check if model is already quantized
if model_config._is_already_quantized():
logger.info("Model is already quantized, loading directly...")
# Use default loading for pre-quantized models
return super().load_model(
model_config=model_config, device_config=device_config
)
# TODO: Quantize-and-serve mode has been disabled at the ModelConfig level
# All quantization now uses the standard workflow (quantize + export/save)
logger.info("Standard quantization mode: Will quantize and export/save")
return self._standard_quantization_workflow(model_config, device_config)
def _standard_quantization_workflow(
self, model_config: ModelConfig, device_config: DeviceConfig
) -> nn.Module:
"""Standard quantization workflow: quantize, save checkpoint, export, then return model."""
# Use shared method from parent class to load base model for quantization
model = self._load_modelopt_base_model(model_config)
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
# Import ModelOpt modules
try:
import modelopt.torch.quantization as mtq
except ImportError:
logger.error(
"NVIDIA Model Optimizer (modelopt) library not found. "
"Please install it to use 'modelopt_quant' feature."
"Please install it to use ModelOpt quantization."
)
raise
quant_choice_str = model_config.modelopt_quant
# Handle both old modelopt_quant and new unified quantization flags
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
# Legacy modelopt_quant flag
quant_choice_str = model_config.modelopt_quant
else:
# Unified quantization flag - extract the type (fp8/fp4)
quant_choice_str = model_config._get_modelopt_quant_type()
quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
if not quant_cfg_name:
raise ValueError(
f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
"Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
"attribute names of config objects in modelopt.torch.quantization."
f"Invalid quantization choice: '{quant_choice_str}'. "
f"Available choices: {list(QUANT_CFG_CHOICES.keys())}"
)
try:
@@ -1885,20 +1989,27 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg = getattr(mtq, quant_cfg_name)
except AttributeError:
raise AttributeError(
f"ModelOpt quantization config attribute '{quant_cfg_name}' "
f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
f"ModelOpt quantization config '{quant_cfg_name}' not found. "
"Please verify the ModelOpt library installation."
)
logger.info(
f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
f"Quantizing model with ModelOpt using config: mtq.{quant_cfg_name}"
)
quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
# Get ModelOpt configuration from LoadConfig
modelopt_config = self.load_config.modelopt_config
quantized_ckpt_restore_path = (
modelopt_config.checkpoint_restore_path if modelopt_config else None
)
quantized_ckpt_save_path = (
modelopt_config.checkpoint_save_path if modelopt_config else None
)
export_path = modelopt_config.export_path if modelopt_config else None
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_path, use_fast=True
)
try:
self._setup_modelopt_quantization(
model,
@@ -1906,6 +2017,7 @@ class ModelOptModelLoader(DefaultModelLoader):
quant_cfg,
quantized_ckpt_restore_path=quantized_ckpt_restore_path,
quantized_ckpt_save_path=quantized_ckpt_save_path,
export_path=export_path,
)
except Exception as e:
logger.warning(f"ModelOpt quantization failed: {e}")
@@ -1919,12 +2031,27 @@ def get_model_loader(
) -> BaseModelLoader:
"""Get a model loader based on the load format."""
if model_config and (
(hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant)
or model_config.quantization in ["modelopt_fp8", "modelopt_fp4", "modelopt"]
):
logger.info("Using ModelOptModelLoader due to ModelOpt quantization config.")
return ModelOptModelLoader(load_config)
# Use ModelOptModelLoader for unified quantization flags
if (
model_config
and hasattr(model_config, "modelopt_quant")
and model_config.modelopt_quant
and hasattr(model_config, "quantization")
and model_config.quantization in ["modelopt_fp8", "modelopt_fp4"]
):
logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
if model_config._is_already_quantized():
logger.info(
f"Using ModelOptModelLoader for pre-quantized model: {model_config.quantization}"
)
else:
logger.info(
f"Using ModelOptModelLoader for quantization: {model_config.quantization}"
)
return ModelOptModelLoader(load_config)
if isinstance(load_config.load_format, type):

View File

@@ -83,6 +83,7 @@ QUANTIZATION_CHOICES = [
"bitsandbytes",
"gguf",
"modelopt",
"modelopt_fp8",
"modelopt_fp4",
"petit_nvfp4",
"w8a8_int8",
@@ -192,6 +193,8 @@ class ServerArgs:
modelopt_quant: Optional[Union[str, Dict]] = None
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
modelopt_export_path: Optional[str] = None
quantize_and_serve: bool = False
context_length: Optional[int] = None
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
@@ -1743,6 +1746,22 @@ class ServerArgs:
help="Path to save the ModelOpt quantized checkpoint after quantization. "
"This allows reusing the quantized model in future runs.",
)
parser.add_argument(
"--modelopt-export-path",
type=str,
default=ServerArgs.modelopt_export_path,
help="Path to export the quantized model in HuggingFace format after ModelOpt quantization. "
"The exported model can then be used directly with SGLang for inference. "
"If not provided, the model will not be exported.",
)
parser.add_argument(
"--quantize-and-serve",
action="store_true",
default=ServerArgs.quantize_and_serve,
help="Quantize the model with ModelOpt and immediately serve it without exporting. "
"This is useful for development and prototyping. For production, it's recommended "
"to use separate quantization and deployment steps.",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,

View File

@@ -2411,6 +2411,29 @@ def retry(
time.sleep(delay)
def has_hf_quant_config(model_path: str) -> bool:
    """Check if the model path contains hf_quant_config.json file.

    Args:
        model_path: Path to the model, can be local path or remote URL.

    Returns:
        True if hf_quant_config.json exists, False otherwise. Remote
        lookup errors are swallowed and reported as False.
    """
    config_name = "hf_quant_config.json"
    if not is_remote_url(model_path):
        import os

        return os.path.exists(os.path.join(model_path, config_name))
    try:
        from huggingface_hub import HfApi

        return HfApi().file_exists(model_path, config_name)
    except Exception:
        # Best-effort remote check: any hub/network failure means "not found".
        return False
def flatten_nested_list(nested_list):
if isinstance(nested_list, list):
return [