# SPDX-License-Identifier: Apache-2.0

import ast
import copy
import enum
import hashlib
import inspect
import json
import re
import sys
import textwrap
import warnings
from collections import Counter
from contextlib import contextmanager
from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                         replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
                    Optional, Protocol, TypeVar, Union, get_args)

import torch
from pydantic import BaseModel, Field, PrivateAttr
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
                                                     QuantizationMethods,
                                                     get_quantization_config)
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import CpuArchEnum, current_platform
from vllm.config.model import _STR_DTYPE_TO_TORCH_DTYPE

logger = init_logger(__name__)

def ModelConfig___verify_quantization(self) -> None:
    # ModelConfig._verify_quantization, lifted out as a module-level function
    # (hence the explicit `self` parameter).
    supported_quantization = QUANTIZATION_METHODS
    optimized_quantization_methods = [
        "fp8", "modelopt", "gptq_marlin_24", "gptq_marlin",
        "awq_marlin", "fbgemm_fp8", "compressed-tensors", "experts_int8",
        "quark", "modelopt_fp4", "bitblas"  # , "gptq_bitblas"
    ]
    if self.quantization is not None:
        self.quantization = self.quantization.lower()

    # Parse quantization method from the HF model config, if available.
    quant_cfg = self._parse_quant_hf_config(self.hf_config)
    if quant_cfg is None and (text_config := getattr(
            self.hf_config, "text_config", None)):
        # Check the text config as well for multi-modal models.
        quant_cfg = self._parse_quant_hf_config(text_config)

    if quant_cfg is not None:
        quant_method = quant_cfg.get("quant_method", "").lower()
        quant_method = quant_method.replace("compressed_tensors",
                                            "compressed-tensors")
        quant_cfg["quant_method"] = quant_method

        # Quantization methods which are overrides (i.e. they have an
        # `override_quantization_method` method) must be checked in order
        # of preference (this is particularly important for GPTQ).
        overrides = [
            # "marlin",
            "bitblas",
            "gptq_marlin_24",
            "gptq_marlin",
            # "gptq_bitblas",
            "awq_marlin",
            "ipex",
            "moe_wna16",
            "modelopt",
            "modelopt_fp4",
            "petit_nvfp4",
        ]
        quantization_methods = [
            q for q in supported_quantization if q not in overrides
        ]
        # Any custom overrides will already be in `quantization_methods`;
        # appending the built-in overrides afterwards keeps the custom ones
        # at the start of the list, so they take preference over the
        # built-in ones.
        quantization_methods = quantization_methods + overrides

        # Detect which kind of checkpoint this is.
        for name in quantization_methods:
            method = get_quantization_config(name)
            quantization_override = method.override_quantization_method(
                quant_cfg, self.quantization)
            if quantization_override is not None:
                # Raise an error if the override is not custom (custom would
                # be in QUANTIZATION_METHODS but not QuantizationMethods)
                # and has not been added to the overrides list.
                if (name in get_args(QuantizationMethods)
                        and name not in overrides):
                    raise ValueError(
                        f"Quantization method {name} is an override but "
                        "has not been added to the `overrides` list "
                        "above. This is necessary to ensure that the "
                        "overrides are checked in order of preference.")
                quant_method = quantization_override
                self.quantization = quantization_override
                break

        # Verify quantization configurations.
        if self.quantization is None:
            self.quantization = quant_method
        elif self.quantization != quant_method:
            raise ValueError(
                "Quantization method specified in the model config "
                f"({quant_method}) does not match the quantization "
                "method specified in the `quantization` argument "
                f"({self.quantization}).")

    if self.quantization is not None:
        if self.quantization not in supported_quantization:
            raise ValueError(
                f"Unknown quantization method: {self.quantization}. Must "
                f"be one of {supported_quantization}.")
        from vllm.platforms import current_platform
        current_platform.verify_quantization(self.quantization)
        if self.quantization not in optimized_quantization_methods:
            logger.warning(
                "%s quantization is not fully "
                "optimized yet. The speed can be slower than "
                "non-quantized models.", self.quantization)

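# Minimal, hypothetical sketch (not part of the upstream module) of the kind
# of HF `quantization_config` dict that reaches the override loop above. The
# field values are illustrative only; the point is that a plain "gptq"
# checkpoint is expected to be upgraded to "gptq_marlin" on supported
# hardware, because the marlin override is consulted earlier in the
# preference order.
_EXAMPLE_GPTQ_HF_QUANT_CFG = {
    "quant_method": "gptq",  # as written by AutoGPTQ-style exporters
    "bits": 4,
    "group_size": 128,
    "desc_act": False,
}
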
def _get_head_dtype(config: PretrainedConfig, dtype: torch.dtype,
                    runner_type: str) -> torch.dtype:
    head_dtype: Optional[Union[str,
                               torch.dtype]] = getattr(config, "head_dtype",
                                                       None)

    if head_dtype == "model":
        return dtype
    elif isinstance(head_dtype, str):
        head_dtype = head_dtype.lower()
        if head_dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
            raise ValueError(f"Unknown dtype: {head_dtype!r}")
        return _STR_DTYPE_TO_TORCH_DTYPE[head_dtype]
    elif isinstance(head_dtype, torch.dtype):
        return head_dtype
    elif head_dtype is None:
        if torch.float32 not in current_platform.supported_dtypes:
            return dtype
        if runner_type == "pooling":
            # Pooling heads default to a full-precision head when the
            # platform supports float32 (which the guard above checks).
            return torch.float32
        return dtype
    else:
        raise ValueError(f"Unknown dtype: {head_dtype}")
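# Illustrative usage sketch (not part of the upstream module). SimpleNamespace
# stands in for a PretrainedConfig that may or may not carry a `head_dtype`
# attribute; the None case depends on current_platform, so the noted results
# are only what is expected on a platform that supports float32.
#
#     from types import SimpleNamespace
#
#     # An explicit string on the config wins and is mapped via
#     # _STR_DTYPE_TO_TORCH_DTYPE.
#     _get_head_dtype(SimpleNamespace(head_dtype="float32"),
#                     torch.bfloat16, runner_type="generate")  # torch.float32
#
#     # "model" means: reuse the model's own dtype.
#     _get_head_dtype(SimpleNamespace(head_dtype="model"),
#                     torch.bfloat16, runner_type="generate")  # torch.bfloat16
#
#     # No attribute set: pooling runners prefer a float32 head when the
#     # platform supports it; otherwise the model dtype is used.
#     _get_head_dtype(SimpleNamespace(), torch.float16,
#                     runner_type="pooling")                   # torch.float32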