init
This commit is contained in:
76
model_executor/layers/quantization/neuron_quant.py
Normal file
76
model_executor/layers/quantization/neuron_quant.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from importlib.util import find_spec
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.nn import Module
|
||||
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
|
||||
SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn']
|
||||
|
||||
|
||||
class AlwaysSupportedDtypes(list):
|
||||
|
||||
def __contains__(self, item):
|
||||
return True
|
||||
|
||||
|
||||
class NeuronQuantConfig(QuantizationConfig):
|
||||
"""Int8 Quantization Config class for Neuron Backend."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dequant_dtype: str = "f16",
|
||||
quantize_method: str = "vector_dynamic",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8")
|
||||
if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST:
|
||||
raise ValueError(
|
||||
f"Neuron quantization datatype {self.quant_dtype} is not valid,"
|
||||
f" the quantization datatype should match one of the below "
|
||||
f"types {SUPPORTED_QUANT_DTYPE_LIST}")
|
||||
self.dequant_dtype = dequant_dtype
|
||||
self.quantize_method = quantize_method
|
||||
|
||||
def get_name(self) -> QuantizationMethods:
|
||||
return "neuron_quant"
|
||||
|
||||
def get_supported_act_dtypes(self) -> list[str]:
|
||||
# Neuron implements custom handling logic for quantization support
|
||||
return AlwaysSupportedDtypes()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
raise NotImplementedError(
|
||||
"This function should not be called with Neuron Backend")
|
||||
|
||||
@staticmethod
|
||||
def get_config_filenames() -> list[str]:
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "NeuronQuantConfig":
|
||||
quantize_method = cls.get_from_keys(config, ["quantize_method"])
|
||||
dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"])
|
||||
return cls(dequant_dtype=dequant_dtype,
|
||||
quantize_method=quantize_method)
|
||||
|
||||
def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]:
|
||||
if find_spec("transformers_neuronx") is not None:
|
||||
return self.get_quantization_config()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Neuron Quantization is only supported through"
|
||||
" transformers_neuronx.")
|
||||
|
||||
def get_quantization_config(self):
|
||||
from transformers_neuronx.config import QuantizationConfig
|
||||
return QuantizationConfig(quant_dtype=self.quant_dtype,
|
||||
dequant_dtype=self.dequant_dtype,
|
||||
quantize_method=self.quantize_method)
|
||||
Reference in New Issue
Block a user