diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py
index ce54ef802..03f72ccdf 100644
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -86,6 +86,8 @@ class ModelConfig:
         dtype: str = "auto",
         quantization: Optional[str] = None,
         modelopt_quant: Optional[Union[str, Dict]] = None,
+        modelopt_checkpoint_restore_path: Optional[str] = None,
+        modelopt_checkpoint_save_path: Optional[str] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
diff --git a/python/sglang/srt/model_loader/loader.py b/python/sglang/srt/model_loader/loader.py
index 8b6676141..de58a8dd7 100644
--- a/python/sglang/srt/model_loader/loader.py
+++ b/python/sglang/srt/model_loader/loader.py
@@ -18,7 +18,7 @@ import threading
 import time
 from abc import ABC, abstractmethod
 from concurrent.futures import ThreadPoolExecutor
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -30,7 +30,6 @@ from typing import (
     Tuple,
     cast,
 )
-from urllib.parse import urlparse
 
 import huggingface_hub
 import numpy as np
@@ -52,7 +51,7 @@ except ImportError:
 
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoConfig, AutoModelForCausalLM
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -104,6 +103,7 @@ from sglang.srt.utils import (
     get_device_capability,
     is_npu,
     is_pin_memory_available,
+    rank0_log,
     set_weight_attrs,
 )
 
@@ -545,7 +545,7 @@ class DefaultModelLoader(BaseModelLoader):
             **model_kwargs,
             trust_remote_code=True,
         )
-        logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
+        rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
 
         quant_choice_str = model_config.modelopt_quant
         if not isinstance(quant_choice_str, str):
@@ -1764,6 +1764,96 @@ class ModelOptModelLoader(DefaultModelLoader):
         super().__init__(load_config)
         # Any ModelOpt specific initialization if needed
 
+    def _setup_modelopt_quantization(
+        self,
+        model,
+        tokenizer,
+        quant_cfg,
+        quantized_ckpt_restore_path: str | None = None,
+        quantized_ckpt_save_path: str | None = None,
+    ) -> None:
+        """
+        Set up ModelOpt quantization for the given model.
+
+        Args:
+            model: The model to quantize
+            tokenizer: The tokenizer associated with the model
+            quant_cfg: The quantization configuration
+            quantized_ckpt_restore_path: Path to restore quantized checkpoint from
+            quantized_ckpt_save_path: Path to save quantized checkpoint to
+
+        Raises:
+            ImportError: If ModelOpt is not available
+            Exception: If quantization setup fails
+        """
+        try:
+            import modelopt.torch.opt as mto
+            import modelopt.torch.quantization as mtq
+            from modelopt.torch.quantization.utils import is_quantized
+        except ImportError as e:
+            raise ImportError(
+                "ModelOpt is not available. Please install modelopt."
+            ) from e
+
+        if is_quantized(model):
+            rank0_log("Model is already quantized, skipping quantization setup.")
+            return
+        # Restore from checkpoint if provided
+        if quantized_ckpt_restore_path:
+            try:
+                mto.restore(model, quantized_ckpt_restore_path)
+                rank0_log(
+                    f"Restored quantized model from {quantized_ckpt_restore_path}"
+                )
+                return
+            except Exception as e:
+                logger.warning(
+                    f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
+                )
+                rank0_log("Proceeding with calibration-based quantization...")
+
+        # Set up calibration-based quantization
+        try:
+            # Left padding tends to work better for batched generation with decoder-only LMs
+            with suppress(Exception):
+                tokenizer.padding_side = "left"
+
+            from modelopt.torch.utils.dataset_utils import (
+                create_forward_loop,
+                get_dataset_dataloader,
+            )
+
+            # Create calibration dataloader
+            calib_dataloader = get_dataset_dataloader(
+                dataset_name="cnn_dailymail",  # TODO: Consider making this configurable
+                tokenizer=tokenizer,
+                batch_size=36,  # TODO: Consider making this configurable
+                num_samples=512,  # TODO: Consider making this configurable
+                device=model.device,
+                include_labels=False,
+            )
+
+            calibrate_loop = create_forward_loop(dataloader=calib_dataloader)
+
+            # Apply quantization
+            mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
+
+            if get_tensor_model_parallel_rank() == 0:
+                mtq.print_quant_summary(model)
+
+            # Save checkpoint if path provided
+            if quantized_ckpt_save_path:
+                try:
+                    mto.save(model, quantized_ckpt_save_path)
+                    rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
+                except Exception as e:
+                    logger.warning(
+                        f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
+                    )
+
+        except Exception as e:
+            raise Exception(f"Failed to set up ModelOpt quantization: {e}") from e
+
     def load_model(
         self,
         *,
@@ -1779,7 +1869,6 @@ class ModelOptModelLoader(DefaultModelLoader):
         # Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
         try:
             import modelopt.torch.quantization as mtq
-            from modelopt.torch.utils.dataset_utils import create_forward_loop
         except ImportError:
             logger.error(
                 "NVIDIA Model Optimizer (modelopt) library not found. "
@@ -1808,33 +1897,26 @@ class ModelOptModelLoader(DefaultModelLoader):
                 "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
             )
 
-        # For now, assume no calibration. Calibration setup is a separate, more complex step.
-        use_calibration = False  # This would ideally be a configurable parameter
-        calib_dataloader = None  # This would need to be provided/configured
-
-        calibrate_loop = (
-            create_forward_loop(dataloader=calib_dataloader)
-            if use_calibration
-            else None
-        )
-
-        if use_calibration and calib_dataloader is None:
-            logger.warning(
-                "ModelOpt calibration requested but no calib_dataloader provided. "
-                "Proceeding without calibration. Quantization accuracy may be affected."
-            )
-
         logger.info(
             f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
        )
+        quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
+        quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_config.model_path, use_fast=True
+        )
 
         try:
-            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
-            logger.info("Model successfully quantized with ModelOpt.")
+            self._setup_modelopt_quantization(
+                model,
+                tokenizer,
+                quant_cfg,
+                quantized_ckpt_restore_path=quantized_ckpt_restore_path,
+                quantized_ckpt_save_path=quantized_ckpt_save_path,
+            )
         except Exception as e:
-            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
-            raise
-        mtq.print_quant_summary(model)
+            logger.warning(f"ModelOpt quantization failed: {e}")
+            rank0_log("Proceeding without quantization...")
 
         return model.eval()
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 3b205ef8f..abb850160 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -178,6 +178,8 @@ class ServerArgs:
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     modelopt_quant: Optional[Union[str, Dict]] = None
+    modelopt_checkpoint_restore_path: Optional[str] = None
+    modelopt_checkpoint_save_path: Optional[str] = None
     context_length: Optional[int] = None
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
@@ -1504,6 +1506,21 @@ class ServerArgs:
             "Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
             "This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
         )
+        parser.add_argument(
+            "--modelopt-checkpoint-restore-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_restore_path,
+            help="Path to restore a previously saved ModelOpt quantized checkpoint. "
+            "If provided, the quantization process will be skipped and the model "
+            "will be loaded from this checkpoint.",
+        )
+        parser.add_argument(
+            "--modelopt-checkpoint-save-path",
+            type=str,
+            default=ServerArgs.modelopt_checkpoint_save_path,
+            help="Path to save the ModelOpt quantized checkpoint after quantization. "
+            "This allows reusing the quantized model in future runs.",
+        )
         parser.add_argument(
             "--kv-cache-dtype",
             type=str,
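
Example usage of the new checkpoint options (a sketch; the launch entry point is the standard SGLang server, but the model name and checkpoint path below are illustrative placeholders, and only the --modelopt-* options come from this diff):

    # First run: quantize with ModelOpt calibration, then persist the quantized state.
    #   python -m sglang.launch_server --model-path <model> --modelopt-quant fp8 \
    #       --modelopt-checkpoint-save-path /tmp/modelopt_fp8_ckpt
    #
    # Later runs: restore the saved state and skip calibration entirely.
    #   python -m sglang.launch_server --model-path <model> --modelopt-quant fp8 \
    #       --modelopt-checkpoint-restore-path /tmp/modelopt_fp8_ckpt

    # The same options can also be set programmatically via the dataclass fields added above.
    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="<model>",  # illustrative placeholder
        modelopt_quant="fp8",
        modelopt_checkpoint_save_path="/tmp/modelopt_fp8_ckpt",
    )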