Enable native ModelOpt quantization support (2/3) (#9991)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
@@ -86,6 +86,8 @@ class ModelConfig:
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
modelopt_quant: Optional[Union[str, Dict]] = None,
|
||||
modelopt_checkpoint_restore_path: Optional[str] = None,
|
||||
modelopt_checkpoint_save_path: Optional[str] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
|
||||
@@ -18,7 +18,7 @@ import threading
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, suppress
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
@@ -30,7 +30,6 @@ from typing import (
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import huggingface_hub
|
||||
import numpy as np
|
||||
@@ -52,7 +51,7 @@ except ImportError:
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
@@ -104,6 +103,7 @@ from sglang.srt.utils import (
|
||||
get_device_capability,
|
||||
is_npu,
|
||||
is_pin_memory_available,
|
||||
rank0_log,
|
||||
set_weight_attrs,
|
||||
)
|
||||
|
||||
@@ -545,7 +545,7 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
**model_kwargs,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
|
||||
rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
|
||||
|
||||
quant_choice_str = model_config.modelopt_quant
|
||||
if not isinstance(quant_choice_str, str):
|
||||
@@ -1764,6 +1764,96 @@ class ModelOptModelLoader(DefaultModelLoader):
|
||||
super().__init__(load_config)
|
||||
# Any ModelOpt specific initialization if needed
|
||||
|
||||
def _setup_modelopt_quantization(
|
||||
self,
|
||||
model,
|
||||
tokenizer,
|
||||
quant_cfg,
|
||||
quantized_ckpt_restore_path: str | None = None,
|
||||
quantized_ckpt_save_path: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Set up ModelOpt quantization for the given model.
|
||||
|
||||
Args:
|
||||
model: The model to quantize
|
||||
tokenizer: The tokenizer associated with the model
|
||||
quant_cfg: The quantization configuration
|
||||
quantized_ckpt_restore_path: Path to restore quantized checkpoint from
|
||||
quantized_ckpt_save_path: Path to save quantized checkpoint to
|
||||
|
||||
Raises:
|
||||
ImportError: If ModelOpt is not available
|
||||
Exception: If quantization setup fails
|
||||
"""
|
||||
try:
|
||||
import modelopt.torch.opt as mto
|
||||
import modelopt.torch.quantization as mtq
|
||||
from modelopt.torch.quantization.utils import is_quantized
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"ModelOpt is not available. Please install modelopt."
|
||||
) from e
|
||||
|
||||
if is_quantized(model):
|
||||
rank0_log("Model is already quantized, skipping quantization setup.")
|
||||
return
|
||||
# Restore from checkpoint if provided
|
||||
if quantized_ckpt_restore_path:
|
||||
try:
|
||||
mto.restore(model, quantized_ckpt_restore_path)
|
||||
rank0_log(
|
||||
f"Restored quantized model from {quantized_ckpt_restore_path}"
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
|
||||
)
|
||||
rank0_log("Proceeding with calibration-based quantization...")
|
||||
|
||||
# Set up calibration-based quantization
|
||||
try:
|
||||
# Left padding tends to work better for batched generation with decoder-only LMs
|
||||
with suppress(Exception):
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
from modelopt.torch.utils.dataset_utils import (
|
||||
create_forward_loop,
|
||||
get_dataset_dataloader,
|
||||
)
|
||||
|
||||
# Create calibration dataloader
|
||||
calib_dataloader = get_dataset_dataloader(
|
||||
dataset_name="cnn_dailymail", # TODO: Consider making this configurable
|
||||
tokenizer=tokenizer,
|
||||
batch_size=36, # TODO: Consider making this configurable
|
||||
num_samples=512, # TODO: Consider making this configurable
|
||||
device=model.device,
|
||||
include_labels=False,
|
||||
)
|
||||
|
||||
calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
|
||||
|
||||
# Apply quantization
|
||||
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||
|
||||
if get_tensor_model_parallel_rank() == 0:
|
||||
mtq.print_quant_summary(model)
|
||||
|
||||
# Save checkpoint if path provided
|
||||
if quantized_ckpt_save_path:
|
||||
try:
|
||||
mto.save(model, quantized_ckpt_save_path)
|
||||
rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
*,
|
||||
@@ -1779,7 +1869,6 @@ class ModelOptModelLoader(DefaultModelLoader):
|
||||
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
|
||||
try:
|
||||
import modelopt.torch.quantization as mtq
|
||||
from modelopt.torch.utils.dataset_utils import create_forward_loop
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"NVIDIA Model Optimizer (modelopt) library not found. "
|
||||
@@ -1808,33 +1897,26 @@ class ModelOptModelLoader(DefaultModelLoader):
|
||||
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
|
||||
)
|
||||
|
||||
# For now, assume no calibration. Calibration setup is a separate, more complex step.
|
||||
use_calibration = False # This would ideally be a configurable parameter
|
||||
calib_dataloader = None # This would need to be provided/configured
|
||||
|
||||
calibrate_loop = (
|
||||
create_forward_loop(dataloader=calib_dataloader)
|
||||
if use_calibration
|
||||
else None
|
||||
)
|
||||
|
||||
if use_calibration and calib_dataloader is None:
|
||||
logger.warning(
|
||||
"ModelOpt calibration requested but no calib_dataloader provided. "
|
||||
"Proceeding without calibration. Quantization accuracy may be affected."
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
|
||||
)
|
||||
|
||||
quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
|
||||
quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_config.model_path, use_fast=True
|
||||
)
|
||||
try:
|
||||
model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
|
||||
logger.info("Model successfully quantized with ModelOpt.")
|
||||
self._setup_modelopt_quantization(
|
||||
model,
|
||||
tokenizer,
|
||||
quant_cfg,
|
||||
quantized_ckpt_restore_path=quantized_ckpt_restore_path,
|
||||
quantized_ckpt_save_path=quantized_ckpt_save_path,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
|
||||
raise
|
||||
mtq.print_quant_summary(model)
|
||||
logger.warning(f"ModelOpt quantization failed: {e}")
|
||||
rank0_log("Proceeding without quantization...")
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
@@ -178,6 +178,8 @@ class ServerArgs:
|
||||
model_loader_extra_config: str = "{}"
|
||||
trust_remote_code: bool = False
|
||||
modelopt_quant: Optional[Union[str, Dict]] = None
|
||||
modelopt_checkpoint_restore_path: Optional[str] = None
|
||||
modelopt_checkpoint_save_path: Optional[str] = None
|
||||
context_length: Optional[int] = None
|
||||
is_embedding: bool = False
|
||||
enable_multimodal: Optional[bool] = None
|
||||
@@ -1504,6 +1506,21 @@ class ServerArgs:
|
||||
"Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
|
||||
"This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--modelopt-checkpoint-restore-path",
|
||||
type=str,
|
||||
default=ServerArgs.modelopt_checkpoint_restore_path,
|
||||
help="Path to restore a previously saved ModelOpt quantized checkpoint. "
|
||||
"If provided, the quantization process will be skipped and the model "
|
||||
"will be loaded from this checkpoint.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--modelopt-checkpoint-save-path",
|
||||
type=str,
|
||||
default=ServerArgs.modelopt_checkpoint_save_path,
|
||||
help="Path to save the ModelOpt quantized checkpoint after quantization. "
|
||||
"This allows reusing the quantized model in future runs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user