Files

155 lines
6.2 KiB
Python
Raw Permalink Normal View History

2026-04-02 04:53:13 +00:00
# SPDX-License-Identifier: Apache-2.0
import ast
import copy
import enum
import hashlib
import inspect
import json
import re
import sys
import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, get_args)
import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
QuantizationMethods,
get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE
logger = init_logger(__name__)
def ModelConfig___verify_quantization(self) -> None:
    """Resolve and validate ``self.quantization``.

    Steps, in order:

    1. Lower-case a user-supplied ``self.quantization``.
    2. Read the quantization section of ``self.hf_config`` through
       ``self._parse_quant_hf_config`` (falling back to the config's
       ``text_config`` attribute when the top-level config has none).
    3. If a section was found, normalise its ``quant_method`` entry, let
       every registered quantization config offer an override, and check
       that the result agrees with any user-supplied method.
    4. If a method ended up selected, check that it is registered in
       ``QUANTIZATION_METHODS``, hand it to
       ``current_platform.verify_quantization`` and warn when it is not in
       the list of optimized methods.

    The function mutates ``self.quantization`` and, when present, the
    ``"quant_method"`` key of the parsed HF quantization dict.

    NOTE(review): the name suggests this is bound as a method on
    ``ModelConfig`` elsewhere -- confirm against the caller.

    Raises:
        ValueError: if a built-in method returns an override without being
            listed in ``overrides``, if the checkpoint's method disagrees
            with the user-supplied one, or if the final method is not in
            ``QUANTIZATION_METHODS``.
    """
    known_methods = QUANTIZATION_METHODS
    # Methods that do NOT trigger the "not fully optimized" warning below.
    fast_methods = (
        "fp8",
        "modelopt",
        "gptq_marlin_24",
        "gptq_marlin",
        "awq_marlin",
        "fbgemm_fp8",
        "compressed-tensors",
        "experts_int8",
        "quark",
        "modelopt_fp4",
        "bitblas",
        # "gptq_bitblas",
    )

    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Pull the quantization section out of the HF config, if there is one.
    hf_quant_cfg = self._parse_quant_hf_config(self.hf_config)
    if hf_quant_cfg is None:
        # Multi-modal models may carry it on the nested text config instead.
        text_config = getattr(self.hf_config, "text_config", None)
        if text_config:
            hf_quant_cfg = self._parse_quant_hf_config(text_config)

    if hf_quant_cfg is not None:
        # Normalise the spelling and write it back into the parsed config.
        ckpt_method = (hf_quant_cfg.get("quant_method", "").lower().replace(
            "compressed_tensors", "compressed-tensors"))
        hf_quant_cfg["quant_method"] = ckpt_method

        # Methods that implement `override_quantization_method` have to be
        # consulted in order of preference (this matters most for GPTQ),
        # so the order of this list is significant.
        overrides = [
            # "marlin",
            "bitblas",
            "gptq_marlin_24",
            "gptq_marlin",
            # "gptq_bitblas",
            "awq_marlin",
            "ipex",
            "moe_wna16",
            "modelopt",
            "modelopt_fp4",
            "petit_nvfp4",
        ]
        # Everything that is not a listed override goes first. Custom
        # (out-of-tree) overrides land in that leading group, which gives
        # them preference over the built-in ones.
        probe_order = [m for m in known_methods if m not in overrides]
        probe_order = probe_order + overrides

        # Ask each method whether it wants to take over this checkpoint;
        # the first one that answers wins.
        for candidate in probe_order:
            override = get_quantization_config(
                candidate).override_quantization_method(
                    hf_quant_cfg, self.quantization)
            if override is None:
                continue
            # A built-in method (present in QuantizationMethods) that
            # overrides must be registered in `overrides`; custom methods
            # are in QUANTIZATION_METHODS but not in QuantizationMethods
            # and are therefore exempt.
            if (candidate in get_args(QuantizationMethods)
                    and candidate not in overrides):
                raise ValueError(
                    f"Quantization method {candidate} is an override but "
                    "is has not been added to the `overrides` list "
                    "above. This is necessary to ensure that the "
                    "overrides are checked in order of preference.")
            ckpt_method = override
            self.quantization = override
            break

        # Reconcile the checkpoint's method with the user's request.
        if self.quantization is None:
            self.quantization = ckpt_method
        elif self.quantization != ckpt_method:
            raise ValueError(
                "Quantization method specified in the model config "
                f"({ckpt_method}) does not match the quantization "
                f"method specified in the `quantization` argument "
                f"({self.quantization}).")

    # Nothing selected: the model is not quantized, nothing left to verify.
    if self.quantization is None:
        return

    if self.quantization not in known_methods:
        raise ValueError(
            f"Unknown quantization method: {self.quantization}. Must "
            f"be one of {known_methods}.")

    from vllm.platforms import current_platform
    current_platform.verify_quantization(self.quantization)

    if self.quantization not in fast_methods:
        logger.warning(
            "%s quantization is not fully "
            "optimized yet. The speed can be slower than "
            "non-quantized models.", self.quantization)
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    """Pick the dtype to use for the model head.

    The decision is driven by the optional ``head_dtype`` attribute of
    ``config``:

    * the string ``"model"`` (any letter case) -> return ``dtype``;
    * any other string -> lower-cased and looked up in
      ``_STR_DTYPE_TO_TORCH_DTYPE``;
    * a ``torch.dtype`` -> returned unchanged;
    * missing / ``None`` -> ``dtype``, except that the ``"pooling"`` runner
      gets ``torch.float16`` when the current platform lists
      ``torch.float32`` among its supported dtypes.

    Args:
        config: HF model config; only its ``head_dtype`` attribute is read.
        dtype: The model dtype, used as the fallback result.
        runner_type: Runner kind; only the value ``"pooling"`` is special.

    Returns:
        The resolved ``torch.dtype`` for the head.

    Raises:
        ValueError: if ``head_dtype`` is a string that is not a known dtype
            name, or is neither a string, a ``torch.dtype`` nor ``None``.
    """
    head_dtype: Optional[Union[str, torch.dtype]] = getattr(
        config, "head_dtype", None)

    if isinstance(head_dtype, str):
        # Normalise the case once so that the "model" sentinel is matched
        # as leniently as the dtype names are. Previously "Model" fell
        # through to the lookup and was rejected as an unknown dtype.
        normalized = head_dtype.lower()
        if normalized == "model":
            return dtype
        if normalized not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {normalized!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[normalized]

    if isinstance(head_dtype, torch.dtype):
        return head_dtype

    if head_dtype is None:
        # Without float32 support on this platform, stay on the model dtype.
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        # NOTE(review): the guard above tests float32 support, yet the
        # pooling default returned here is float16 -- presumably a
        # deliberate platform-specific choice; confirm with the owner.
        if runner_type == "pooling":
            return torch.float16
        return dtype

    raise ValueError(f"Unknown dtype: {head_dtype}")