Enable native ModelOpt quantization support (1/3) (#7149)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import logging
|
||||
import math
|
||||
import os
|
||||
from enum import Enum, IntEnum, auto
|
||||
from typing import List, Optional, Set, Union
|
||||
from typing import Dict, List, Optional, Set, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -85,6 +85,7 @@ class ModelConfig:
|
||||
enable_multimodal: Optional[bool] = None,
|
||||
dtype: str = "auto",
|
||||
quantization: Optional[str] = None,
|
||||
modelopt_quant: Optional[Union[str, Dict]] = None,
|
||||
override_config_file: Optional[str] = None,
|
||||
is_draft_model: bool = False,
|
||||
hybrid_kvcache_ratio: Optional[float] = None,
|
||||
@@ -94,6 +95,7 @@ class ModelConfig:
|
||||
self.model_path = model_path
|
||||
self.revision = revision
|
||||
self.quantization = quantization
|
||||
self.modelopt_quant = modelopt_quant
|
||||
self.is_draft_model = is_draft_model
|
||||
self.model_impl = model_impl
|
||||
|
||||
@@ -209,6 +211,7 @@ class ModelConfig:
|
||||
enable_multimodal=server_args.enable_multimodal,
|
||||
dtype=server_args.dtype,
|
||||
quantization=server_args.quantization,
|
||||
modelopt_quant=server_args.modelopt_quant,
|
||||
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
||||
model_impl=server_args.model_impl,
|
||||
**kwargs,
|
||||
@@ -477,54 +480,52 @@ class ModelConfig:
|
||||
# example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
|
||||
# example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
|
||||
is_local = os.path.exists(self.model_path)
|
||||
modelopt_quant_config = {"quant_method": "modelopt"}
|
||||
if not is_local:
|
||||
import huggingface_hub
|
||||
|
||||
try:
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
hf_api = HfApi()
|
||||
|
||||
def check_hf_quant_config():
|
||||
return hf_api.file_exists(
|
||||
self.model_path, "hf_quant_config.json"
|
||||
if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
|
||||
# Download and parse the quantization config for remote models
|
||||
quant_config_file = hf_hub_download(
|
||||
repo_id=self.model_path,
|
||||
filename="hf_quant_config.json",
|
||||
revision=self.revision,
|
||||
)
|
||||
|
||||
# Retry HF API call up to 3 times
|
||||
file_exists = retry(
|
||||
check_hf_quant_config,
|
||||
max_retry=2,
|
||||
initial_delay=1.0,
|
||||
max_delay=5.0,
|
||||
)
|
||||
|
||||
if file_exists:
|
||||
quant_cfg = modelopt_quant_config
|
||||
|
||||
with open(quant_config_file) as f:
|
||||
quant_config_dict = json.load(f)
|
||||
quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
|
||||
except huggingface_hub.errors.OfflineModeIsEnabled:
|
||||
logger.warning(
|
||||
"Offline mode is enabled, skipping hf_quant_config.json check"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to check hf_quant_config.json: {self.model_path} {e}"
|
||||
)
|
||||
|
||||
pass
|
||||
elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
|
||||
quant_config_file = os.path.join(
|
||||
self.model_path, "hf_quant_config.json"
|
||||
)
|
||||
with open(quant_config_file) as f:
|
||||
quant_config_dict = json.load(f)
|
||||
json_quant_configs = quant_config_dict["quantization"]
|
||||
quant_algo = json_quant_configs.get("quant_algo", None)
|
||||
if quant_algo == "MIXED_PRECISION":
|
||||
quant_cfg = {"quant_method": "w4afp8"}
|
||||
else:
|
||||
quant_cfg = modelopt_quant_config
|
||||
quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
|
||||
return quant_cfg
|
||||
|
||||
def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
|
||||
"""Parse ModelOpt quantization config and return the appropriate quant_method."""
|
||||
json_quant_configs = quant_config_dict["quantization"]
|
||||
quant_algo = json_quant_configs.get("quant_algo", None)
|
||||
|
||||
if quant_algo == "MIXED_PRECISION":
|
||||
return {"quant_method": "w4afp8"}
|
||||
elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
|
||||
return {"quant_method": "modelopt_fp4"}
|
||||
elif quant_algo and "FP8" in quant_algo:
|
||||
return {"quant_method": "modelopt_fp8"}
|
||||
else:
|
||||
# Default to FP8 for backward compatibility
|
||||
return {"quant_method": "modelopt_fp8"}
|
||||
|
||||
# adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = [*QUANTIZATION_METHODS]
|
||||
@@ -543,7 +544,8 @@ class ModelConfig:
|
||||
optimized_quantization_methods = [
|
||||
"fp8",
|
||||
"marlin",
|
||||
"modelopt",
|
||||
"modelopt_fp8",
|
||||
"modelopt_fp4",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
|
||||
11
python/sglang/srt/layers/modelopt_utils.py
Normal file
11
python/sglang/srt/layers/modelopt_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
"""
|
||||
ModelOpt related constants
|
||||
"""
|
||||
|
||||
QUANT_CFG_CHOICES = {
|
||||
"fp8": "FP8_DEFAULT_CFG",
|
||||
"int4_awq": "INT4_AWQ_CFG", # TODO: add support for int4_awq
|
||||
"w4a8_awq": "W4A8_AWQ_BETA_CFG", # TODO: add support for w4a8_awq
|
||||
"nvfp4": "NVFP4_DEFAULT_CFG",
|
||||
"nvfp4_awq": "NVFP4_AWQ_LITE_CFG", # TODO: add support for nvfp4_awq
|
||||
}
|
||||
@@ -72,7 +72,7 @@ if TYPE_CHECKING:
|
||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"fp8": Fp8Config,
|
||||
"blockwise_int8": BlockInt8Config,
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp8": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptFp4Config,
|
||||
"w8a8_int8": W8A8Int8Config,
|
||||
"w8a8_fp8": W8A8Fp8Config,
|
||||
|
||||
@@ -113,7 +113,7 @@ class ModelOptFp8Config(QuantizationConfig):
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "modelopt"
|
||||
return "modelopt_fp8"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
|
||||
@@ -880,7 +880,7 @@ class ModelRunner:
|
||||
load_config = LoadConfig(load_format=load_format)
|
||||
|
||||
# Only support DefaultModelLoader for now
|
||||
loader = get_model_loader(load_config)
|
||||
loader = get_model_loader(load_config, self.model_config)
|
||||
if not isinstance(loader, DefaultModelLoader):
|
||||
message = f"Failed to get model loader: {loader}."
|
||||
return False, message
|
||||
|
||||
@@ -24,7 +24,7 @@ def get_model(
|
||||
load_config: LoadConfig,
|
||||
device_config: DeviceConfig,
|
||||
) -> nn.Module:
|
||||
loader = get_model_loader(load_config)
|
||||
loader = get_model_loader(load_config, model_config)
|
||||
return loader.load_model(
|
||||
model_config=model_config,
|
||||
device_config=device_config,
|
||||
|
||||
@@ -37,10 +37,22 @@ import numpy as np
|
||||
import requests
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
# Try to import accelerate (optional dependency)
|
||||
try:
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from accelerate.utils import get_max_memory
|
||||
|
||||
HAS_ACCELERATE = True
|
||||
except ImportError:
|
||||
HAS_ACCELERATE = False
|
||||
infer_auto_device_map = None
|
||||
init_empty_weights = None
|
||||
get_max_memory = None
|
||||
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
from torch import nn
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoModelForCausalLM
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from sglang.srt.configs.load_config import LoadConfig, LoadFormat
|
||||
@@ -54,6 +66,8 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
|
||||
trigger_transferring_weights_request,
|
||||
)
|
||||
@@ -62,6 +76,11 @@ from sglang.srt.model_loader.utils import (
|
||||
post_load_weights,
|
||||
set_default_torch_dtype,
|
||||
)
|
||||
|
||||
# Constants for memory management
|
||||
DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION = (
|
||||
0.8 # Reserve 20% GPU memory headroom for ModelOpt calibration
|
||||
)
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
_BAR_FORMAT,
|
||||
default_weight_loader,
|
||||
@@ -94,6 +113,8 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
|
||||
_is_npu = is_npu()
|
||||
# ModelOpt: QUANT_CFG_CHOICES is imported from modelopt_utils.py
|
||||
# which contains the complete mapping of quantization config choices
|
||||
|
||||
|
||||
@contextmanager
|
||||
@@ -477,12 +498,78 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
model_config.model_path, model_config.revision, fall_back_to_pt=True
|
||||
)
|
||||
|
||||
def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
    """Load and prepare the base (unquantized) model for ModelOpt quantization.

    This method handles the common model loading logic shared between
    DefaultModelLoader (conditional) and ModelOptModelLoader (dedicated).

    Args:
        model_config: Config providing ``model_path`` and ``modelopt_quant``.

    Returns:
        The HF model, weights loaded and placed via ``device_map="auto"``
        (not yet quantized).

    Raises:
        ImportError: If accelerate is not installed.
        TypeError: If ``modelopt_quant`` is not a string preset key.
    """
    if not HAS_ACCELERATE:
        raise ImportError(
            "accelerate is required for ModelOpt quantization. "
            "Please install it with: pip install accelerate"
        )

    # Validate the quantization choice BEFORE loading weights: the original
    # performed this check only after the full (expensive) model load.
    quant_choice_str = model_config.modelopt_quant
    if not isinstance(quant_choice_str, str):
        raise TypeError(
            f"modelopt_quant must be a string preset key (e.g., 'fp8'), "
            f"got {type(quant_choice_str)}"
        )
    logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")

    hf_config = AutoConfig.from_pretrained(
        model_config.model_path, trust_remote_code=True
    )
    # Build an empty-weight skeleton solely to infer a device map; this tells
    # us whether the real model would spill from GPU(s) to CPU.
    with init_empty_weights():
        torch_dtype = getattr(hf_config, "torch_dtype", torch.float16)
        model = AutoModelForCausalLM.from_config(
            hf_config, torch_dtype=torch_dtype, trust_remote_code=True
        )
    max_memory = get_max_memory()
    inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)

    on_cpu = "cpu" in inferred_device_map.values()
    model_kwargs = {"torch_dtype": "auto"}
    device_map = "auto"

    if on_cpu:
        # Model does not fit on GPU: cap each GPU's budget so ModelOpt
        # calibration has headroom, and let accelerate offload the rest.
        for device in max_memory.keys():
            if isinstance(device, int):
                max_memory[device] *= DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION

        logger.warning(
            "Model does not fit to the GPU mem. "
            f"We apply the following memory limit for calibration: \n{max_memory}\n"
            f"If you hit GPU OOM issue, please adjust the memory fraction "
            f"(currently {DEFAULT_GPU_MEMORY_FRACTION_FOR_CALIBRATION}) or "
            "reduce the calibration `batch_size` manually."
        )
        model_kwargs["max_memory"] = max_memory

    model = AutoModelForCausalLM.from_pretrained(
        model_config.model_path,
        device_map=device_map,
        **model_kwargs,
        trust_remote_code=True,
    )
    return model
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
*,
|
||||
model_config: ModelConfig,
|
||||
device_config: DeviceConfig,
|
||||
) -> nn.Module:
|
||||
|
||||
if hasattr(model_config, "modelopt_quant") and model_config.modelopt_quant:
|
||||
# Load base model using shared method
|
||||
model = self._load_modelopt_base_model(model_config)
|
||||
# Note: DefaultModelLoader doesn't do additional quantization processing
|
||||
# For full ModelOpt quantization, use ModelOptModelLoader
|
||||
return model.eval()
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
@@ -491,9 +578,9 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
self.load_config,
|
||||
)
|
||||
|
||||
self.load_weights_and_postprocess(
|
||||
model, self._get_all_weights(model_config, model), target_device
|
||||
)
|
||||
self.load_weights_and_postprocess(
|
||||
model, self._get_all_weights(model_config, model), target_device
|
||||
)
|
||||
|
||||
return model.eval()
|
||||
|
||||
@@ -1668,9 +1755,103 @@ def load_model_with_cpu_quantization(
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
class ModelOptModelLoader(DefaultModelLoader):
    """
    Model loader that applies NVIDIA Model Optimizer quantization
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # Any ModelOpt specific initialization if needed

    def load_model(
        self,
        *,
        model_config: ModelConfig,
        device_config: DeviceConfig,
    ) -> nn.Module:
        """Load the base model, then quantize it with ModelOpt.

        Raises:
            ImportError: If nvidia-modelopt is not installed.
            ValueError: If ``modelopt_quant`` is not a key of QUANT_CFG_CHOICES.
            AttributeError: If the preset maps to a missing mtq config attribute.
        """
        logger.info("ModelOptModelLoader: Loading base model...")

        # Use shared method from parent class to load base model
        model = self._load_modelopt_base_model(model_config)

        try:
            import modelopt.torch.quantization as mtq
            from modelopt.torch.utils.dataset_utils import create_forward_loop
        except ImportError:
            logger.error(
                "NVIDIA Model Optimizer (modelopt) library not found. "
                "Please install it to use 'modelopt_quant' feature."
            )
            raise

        quant_choice_str = model_config.modelopt_quant

        quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str)
        if not quant_cfg_name:
            raise ValueError(
                f"Invalid modelopt_quant choice: '{quant_choice_str}'. "
                f"Available choices in QUANT_CFG_CHOICES: {list(QUANT_CFG_CHOICES.keys())}. "
                "Ensure QUANT_CFG_CHOICES is correctly defined with mappings to "
                "attribute names of config objects in modelopt.torch.quantization."
            )

        try:
            # getattr will fetch the config object, e.g., mtq.FP8_DEFAULT_CFG
            quant_cfg = getattr(mtq, quant_cfg_name)
        except AttributeError:
            raise AttributeError(
                f"ModelOpt quantization config attribute '{quant_cfg_name}' "
                f"(from choice '{quant_choice_str}') not found in modelopt.torch.quantization module. "
                "Please verify QUANT_CFG_CHOICES and the ModelOpt library."
            )

        # For now, assume no calibration. Calibration setup is a separate, more complex step.
        use_calibration = False  # This would ideally be a configurable parameter
        calib_dataloader = None  # This would need to be provided/configured

        # Warn BEFORE constructing the forward loop: the original built the
        # loop first, so create_forward_loop(dataloader=None) could fail
        # before this warning was ever emitted.
        if use_calibration and calib_dataloader is None:
            logger.warning(
                "ModelOpt calibration requested but no calib_dataloader provided. "
                "Proceeding without calibration. Quantization accuracy may be affected."
            )

        calibrate_loop = (
            create_forward_loop(dataloader=calib_dataloader)
            if use_calibration
            else None
        )

        logger.info(
            f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
        )

        try:
            model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
            logger.info("Model successfully quantized with ModelOpt.")
        except Exception as e:
            logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
            raise
        mtq.print_quant_summary(model)

        return model.eval()
|
||||
|
||||
|
||||
def get_model_loader(
|
||||
load_config: LoadConfig, model_config: Optional[ModelConfig] = None
|
||||
) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
|
||||
if (
|
||||
model_config
|
||||
and hasattr(model_config, "modelopt_quant")
|
||||
and model_config.modelopt_quant
|
||||
):
|
||||
logger.info("Using ModelOptModelLoader due to 'modelopt_quant' config.")
|
||||
return ModelOptModelLoader(load_config)
|
||||
|
||||
if isinstance(load_config.load_format, type):
|
||||
return load_config.load_format(load_config)
|
||||
|
||||
|
||||
@@ -226,6 +226,9 @@ def get_quant_config(
|
||||
return ModelOptFp4Config.from_config(config)
|
||||
else:
|
||||
return quant_cls.from_config(config)
|
||||
elif model_config.quantization == "modelopt_fp8":
|
||||
if config["producer"]["name"] == "modelopt_fp8":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
from typing import List, Literal, Optional, Union
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from sglang.srt.connector import ConnectorType
|
||||
from sglang.srt.function_call.function_call_parser import FunctionCallParser
|
||||
@@ -162,6 +162,7 @@ class ServerArgs:
|
||||
load_format: str = "auto"
|
||||
model_loader_extra_config: str = "{}"
|
||||
trust_remote_code: bool = False
|
||||
modelopt_quant: Optional[Union[str, Dict]] = None
|
||||
context_length: Optional[int] = None
|
||||
is_embedding: bool = False
|
||||
enable_multimodal: Optional[bool] = None
|
||||
@@ -1455,6 +1456,14 @@ class ServerArgs:
|
||||
"KV cache dtype is FP8. Otherwise, KV cache scaling factors "
|
||||
"default to 1.0, which may cause accuracy issues. ",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--modelopt-quant",
|
||||
type=str,
|
||||
default=ServerArgs.modelopt_quant,
|
||||
help="The ModelOpt quantization configuration. "
|
||||
"Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
|
||||
"This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv-cache-dtype",
|
||||
type=str,
|
||||
|
||||
Reference in New Issue
Block a user