Enable native ModelOpt quantization support (1/3) (#7149)
Signed-off-by: Zhiyu Cheng <zhiyuc@nvidia.com>
@@ -17,7 +17,7 @@ import logging
 import math
 import os
 from enum import Enum, IntEnum, auto
-from typing import List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Union
 
 import torch
 from transformers import PretrainedConfig
@@ -85,6 +85,7 @@ class ModelConfig:
         enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
+        modelopt_quant: Optional[Union[str, Dict]] = None,
         override_config_file: Optional[str] = None,
         is_draft_model: bool = False,
         hybrid_kvcache_ratio: Optional[float] = None,
@@ -94,6 +95,7 @@ class ModelConfig:
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
+        self.modelopt_quant = modelopt_quant
         self.is_draft_model = is_draft_model
         self.model_impl = model_impl
 
@@ -209,6 +211,7 @@ class ModelConfig:
             enable_multimodal=server_args.enable_multimodal,
             dtype=server_args.dtype,
             quantization=server_args.quantization,
+            modelopt_quant=server_args.modelopt_quant,
             hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
             model_impl=server_args.model_impl,
             **kwargs,
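Taken together, the constructor and `from_server_args` hunks thread one new option, `modelopt_quant`, through `ModelConfig`. Below is a minimal sketch of how the option is expected to flow, assuming a `ServerArgs`-style object that already carries a `modelopt_quant` field; the concrete CLI flag and the accepted preset values are not shown in this diff and are assumptions here.

```python
from dataclasses import dataclass
from typing import Dict, Optional, Union


@dataclass
class ServerArgsSketch:
    # Hypothetical stand-in for sglang's ServerArgs; only the fields used here.
    model_path: str
    quantization: Optional[str] = None
    # New in this PR: either a preset name (e.g. "fp8") or a full ModelOpt config dict.
    # The exact accepted values are an assumption, not taken from this diff.
    modelopt_quant: Optional[Union[str, Dict]] = None


def build_model_config_kwargs(server_args: ServerArgsSketch) -> dict:
    # Mirrors how ModelConfig.from_server_args now forwards the new field.
    return {
        "quantization": server_args.quantization,
        "modelopt_quant": server_args.modelopt_quant,
    }


print(build_model_config_kwargs(
    ServerArgsSketch("nvidia/Llama-3.1-8B-Instruct-FP8", modelopt_quant="fp8")
))
```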
@@ -477,54 +480,52 @@ class ModelConfig:
         # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
         # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
         is_local = os.path.exists(self.model_path)
         modelopt_quant_config = {"quant_method": "modelopt"}
         if not is_local:
             import huggingface_hub
 
             try:
-                from huggingface_hub import HfApi
+                from huggingface_hub import HfApi, hf_hub_download
 
                 hf_api = HfApi()
-
-                def check_hf_quant_config():
-                    return hf_api.file_exists(
-                        self.model_path, "hf_quant_config.json"
-                    )
-
-                # Retry HF API call up to 3 times
-                file_exists = retry(
-                    check_hf_quant_config,
-                    max_retry=2,
-                    initial_delay=1.0,
-                    max_delay=5.0,
-                )
-
-                if file_exists:
-                    quant_cfg = modelopt_quant_config
-
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    # Download and parse the quantization config for remote models
+                    quant_config_file = hf_hub_download(
+                        repo_id=self.model_path,
+                        filename="hf_quant_config.json",
+                        revision=self.revision,
+                    )
+                    with open(quant_config_file) as f:
+                        quant_config_dict = json.load(f)
+                    quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
             except huggingface_hub.errors.OfflineModeIsEnabled:
                 logger.warning(
                     "Offline mode is enabled, skipping hf_quant_config.json check"
                 )
             except Exception as e:
                 logger.warning(
                     f"Failed to check hf_quant_config.json: {self.model_path} {e}"
                 )
-
-                pass
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
             quant_config_file = os.path.join(
                 self.model_path, "hf_quant_config.json"
             )
             with open(quant_config_file) as f:
                 quant_config_dict = json.load(f)
-            json_quant_configs = quant_config_dict["quantization"]
-            quant_algo = json_quant_configs.get("quant_algo", None)
-            if quant_algo == "MIXED_PRECISION":
-                quant_cfg = {"quant_method": "w4afp8"}
-            else:
-                quant_cfg = modelopt_quant_config
+            quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
         return quant_cfg
 
+    def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+        """Parse ModelOpt quantization config and return the appropriate quant_method."""
+        json_quant_configs = quant_config_dict["quantization"]
+        quant_algo = json_quant_configs.get("quant_algo", None)
+
+        if quant_algo == "MIXED_PRECISION":
+            return {"quant_method": "w4afp8"}
+        elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
+            return {"quant_method": "modelopt_fp4"}
+        elif quant_algo and "FP8" in quant_algo:
+            return {"quant_method": "modelopt_fp8"}
+        else:
+            # Default to FP8 for backward compatibility
+            return {"quant_method": "modelopt_fp8"}
+
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
         supported_quantization = [*QUANTIZATION_METHODS]
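The new `_parse_modelopt_quant_config` helper reduces a checkpoint's `hf_quant_config.json` to one of three internal quant methods: `MIXED_PRECISION` maps to `w4afp8`, any FP4/NVFP4 algorithm maps to `modelopt_fp4`, and FP8 or anything unrecognized falls back to `modelopt_fp8`. A minimal standalone sketch of that mapping follows, with illustrative config payloads shaped like the ones the diff reads; the concrete field values are assumptions.

```python
import json


def parse_modelopt_quant_config(quant_config_dict: dict) -> dict:
    # Same decision logic as ModelConfig._parse_modelopt_quant_config in this diff.
    quant_algo = quant_config_dict["quantization"].get("quant_algo", None)
    if quant_algo == "MIXED_PRECISION":
        return {"quant_method": "w4afp8"}
    elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
        return {"quant_method": "modelopt_fp4"}
    elif quant_algo and "FP8" in quant_algo:
        return {"quant_method": "modelopt_fp8"}
    # Default to FP8 for backward compatibility with older ModelOpt exports.
    return {"quant_method": "modelopt_fp8"}


# Illustrative hf_quant_config.json payloads (values are assumptions):
for payload in (
    '{"quantization": {"quant_algo": "FP8"}}',
    '{"quantization": {"quant_algo": "NVFP4"}}',
    '{"quantization": {"quant_algo": "MIXED_PRECISION"}}',
):
    print(parse_modelopt_quant_config(json.loads(payload)))
# -> modelopt_fp8, modelopt_fp4, w4afp8
```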
@@ -543,7 +544,8 @@ class ModelConfig:
         optimized_quantization_methods = [
             "fp8",
             "marlin",
-            "modelopt",
+            "modelopt_fp8",
+            "modelopt_fp4",
             "gptq_marlin_24",
             "gptq_marlin",
             "awq_marlin",
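Splitting the generic `modelopt` entry into `modelopt_fp8` and `modelopt_fp4` keeps both of the new method names recognized wherever this list is consulted. A hedged sketch of the kind of membership check `_verify_quantization` presumably performs with it; the warning text is an assumption, not taken from this diff.

```python
import logging

logger = logging.getLogger(__name__)

optimized_quantization_methods = [
    "fp8",
    "marlin",
    "modelopt_fp8",
    "modelopt_fp4",
    "gptq_marlin_24",
    "gptq_marlin",
    "awq_marlin",
]


def warn_if_unoptimized(quantization: str) -> None:
    # Sketch: warn when the selected method is not in the optimized list, so
    # "modelopt_fp8"/"modelopt_fp4" checkpoints no longer trigger the warning.
    if quantization not in optimized_quantization_methods:
        logger.warning("%s quantization is not fully optimized yet.", quantization)


warn_if_unoptimized("modelopt_fp4")  # no warning
warn_if_unoptimized("gguf")          # warns
```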