# SPDX-License-Identifier: Apache-2.0
import ast
import copy
import enum
import hashlib
import inspect
import json
import re
import sys
import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                         replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
                    Optional, Protocol, TypeVar, Union, get_args)

import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     QuantizationMethods,
                                                     get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform

logger = init_logger(__name__)


def ModelConfig___verify_quantization(self) -> None:
    supported_quantization = QUANTIZATION_METHODS
    optimized_quantization_methods = [
        "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin", "awq_marlin",
        "fbgemm_fp8", "compressed-tensors", "experts_int8", "quark",
        "modelopt_fp4", "bitblas",  # "gptq_bitblas"
    ]
    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Parse quantization method from the HF model config, if available.
    quant_cfg = self._parse_quant_hf_config(self.hf_config)
    if quant_cfg is None and (text_config := getattr(
            self.hf_config, "text_config", None)):
        # Check the text config as well for multi-modal models.
        quant_cfg = self._parse_quant_hf_config(text_config)

    if quant_cfg is not None:
        quant_method = quant_cfg.get("quant_method", "").lower()
        quant_method = quant_method.replace("compressed_tensors",
                                            "compressed-tensors")
        quant_cfg["quant_method"] = quant_method

        # Quantization methods which are overrides (i.e. they have an
        # `override_quantization_method` method) must be checked in order
        # of preference (this is particularly important for GPTQ).
        overrides = [
            # "marlin",
            "bitblas",
            "gptq_marlin_24",
            "gptq_marlin",
            # "gptq_bitblas",
            "awq_marlin",
            "ipex",
            "moe_wna16",
            "modelopt",
            "modelopt_fp4",
            "petit_nvfp4",
        ]
        quantization_methods = [
            q for q in supported_quantization if q not in overrides
        ]
        # Any custom overrides will be in quantization_methods, so we place
        # them at the start of the list so that custom overrides take
        # preference over the built-in ones.
        quantization_methods = quantization_methods + overrides

        # Detect which checkpoint format this is.
        for name in quantization_methods:
            method = get_quantization_config(name)
            quantization_override = method.override_quantization_method(
                quant_cfg, self.quantization)
            if quantization_override is not None:
                # Raise an error if the override is not custom (custom would
                # be in QUANTIZATION_METHODS but not QuantizationMethods)
                # and has not been added to the overrides list.
                if (name in get_args(QuantizationMethods)
                        and name not in overrides):
                    raise ValueError(
                        f"Quantization method {name} is an override but "
                        "has not been added to the `overrides` list "
                        "above. This is necessary to ensure that the "
                        "overrides are checked in order of preference.")
                quant_method = quantization_override
                self.quantization = quantization_override
                break

        # Verify quantization configurations.
        if self.quantization is None:
            self.quantization = quant_method
        elif self.quantization != quant_method:
            raise ValueError(
                "Quantization method specified in the model config "
                f"({quant_method}) does not match the quantization "
                f"method specified in the `quantization` argument "
                f"({self.quantization}).")

    if self.quantization is not None:
        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")
        from vllm.platforms import current_platform
        current_platform.verify_quantization(self.quantization)
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)


def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    head_dtype: Optional[Union[str, torch.dtype]] = getattr(
        config, "head_dtype", None)

    if head_dtype == "model":
        # "model" keeps the head in the same dtype as the rest of the model.
        return dtype
    elif isinstance(head_dtype, str):
        head_dtype = head_dtype.lower()
        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {head_dtype!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
    elif isinstance(head_dtype, torch.dtype):
        return head_dtype
    elif head_dtype is None:
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            # Pooling heads default to float32 when the platform supports it.
            return torch.float32
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {head_dtype}")
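

# A minimal usage sketch (illustration only, not part of the vLLM API): how
# `_get_head_dtype` is expected to resolve a few `head_dtype` values. The
# config objects, dtype choices, and the "generate" runner type below are
# assumptions made for this example; `PretrainedConfig` is used only as a
# plain holder for the `head_dtype` attribute.
if __name__ == "__main__":
    # "model" keeps the head in the model's own dtype.
    assert _get_head_dtype(PretrainedConfig(head_dtype="model"),
                           torch.bfloat16, "generate") == torch.bfloat16

    # A string value is looked up in _STR_DTYPE_TO_TORCH_DTYPE.
    assert _get_head_dtype(PretrainedConfig(head_dtype="float16"),
                           torch.bfloat16, "generate") == torch.float16

    # With no head_dtype configured, pooling runners get a float32 head when
    # the platform supports float32; other runners keep the model dtype.
    resolved = _get_head_dtype(PretrainedConfig(), torch.bfloat16, "pooling")
    print(f"pooling head dtype resolves to: {resolved}")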