Iluvatar-mrv100 SDK 4.3.0
This commit is contained in:
390
vllm/model_executor/layers/quantization/quark/quark.py
Normal file
390
vllm/model_executor/layers/quantization/quark/quark.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import fnmatch
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501
|
||||
QuarkMoEMethod)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import (
|
||||
QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8)
|
||||
from vllm.model_executor.layers.quantization.quark.utils import (
|
||||
deep_compare, should_ignore_layer)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkLinearMethod"]
|
||||
|
||||
|
||||
class QuarkConfig(QuantizationConfig):
|
||||
|
||||
def __init__(self,
|
||||
quant_config: Dict[str, Any],
|
||||
kv_cache_group: Optional[List[str]] = None,
|
||||
kv_cache_config: Optional[Dict[str, Any]] = None,
|
||||
pack_method: str = "reorder"):
|
||||
super().__init__()
|
||||
if kv_cache_group is None:
|
||||
kv_cache_group = []
|
||||
self.quant_config = quant_config
|
||||
self.kv_cache_group = kv_cache_group
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.pack_method = pack_method
|
||||
|
||||
def get_linear_method(self) -> "QuarkLinearMethod":
|
||||
return QuarkLinearMethod(self)
|
||||
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.float16, torch.bfloat16]
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
return 70
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "quark"
|
||||
|
||||
def get_quant_method(self, layer: torch.nn.Module,
|
||||
prefix: str) -> Optional["QuantizeMethodBase"]:
|
||||
from vllm.attention.layer import Attention # Avoid circular import
|
||||
|
||||
# Check if the layer is skipped for quantization.
|
||||
exclude_layers = cast(List[str], self.quant_config.get("exclude"))
|
||||
if should_ignore_layer(prefix,
|
||||
ignore=exclude_layers,
|
||||
fused_mapping=self.packed_modules_mapping):
|
||||
return UnquantizedLinearMethod()
|
||||
if isinstance(layer, LinearBase):
|
||||
scheme = self.get_scheme(layer=layer, layer_name=prefix)
|
||||
layer.scheme = scheme
|
||||
return QuarkLinearMethod(self)
|
||||
if isinstance(layer, Attention):
|
||||
return QuarkKVCacheMethod(self)
|
||||
if isinstance(layer, FusedMoE):
|
||||
return QuarkMoEMethod.get_moe_method(self,
|
||||
module=layer,
|
||||
layer_name=prefix)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig":
|
||||
export_config = config.get("export")
|
||||
if export_config is None:
|
||||
raise ValueError("The export key should be included in "
|
||||
"the configurations of Quark quantized model")
|
||||
kv_cache_group = cast(List[str], export_config.get("kv_cache_group"))
|
||||
pack_method = cast(str, export_config.get("pack_method"))
|
||||
|
||||
# In the export model of quark, the quantization configuration
|
||||
# of kv_cache is stored in layer_quant_config. First, it is
|
||||
# judged whether kv_cache_group exists, and then it is judged
|
||||
# whether layer_quant_config has a quantization configuration
|
||||
# that matches kv_cache.
|
||||
if len(kv_cache_group) == 0:
|
||||
kv_cache_config = None
|
||||
else:
|
||||
kv_cache_set = set(kv_cache_group)
|
||||
layer_quant_config = cast(Dict[str, Any],
|
||||
config.get("layer_quant_config"))
|
||||
layer_quant_names = list(layer_quant_config.keys())
|
||||
layer_quant_set = set(layer_quant_names)
|
||||
|
||||
if not kv_cache_set.issubset(layer_quant_set):
|
||||
raise ValueError("The Quark quantized model has the "
|
||||
"kv_cache_group parameter setting, "
|
||||
"but no kv_cache quantization settings "
|
||||
"were found in the quantization "
|
||||
"configuration.")
|
||||
|
||||
q_configs = [
|
||||
cast(Dict[str, Any], layer_quant_config.get(name))
|
||||
for name in kv_cache_group
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, q_configs[0])
|
||||
for q_config in q_configs):
|
||||
raise ValueError(
|
||||
"The quantization method used for kv_cache should "
|
||||
"be the same, but the quantization method for the "
|
||||
"kv_cache layer in the config is different.")
|
||||
kv_cache_config = q_configs[0].get("output_tensors")
|
||||
if kv_cache_config is None:
|
||||
raise ValueError(
|
||||
"The kv_cache quantization configuration is empty.")
|
||||
|
||||
# Since we have already set kv_cache quantization configurations,
|
||||
# we will remove the quantization configuration for the
|
||||
# output_tensors corresponding to the kv_cache layer.
|
||||
for q_config in q_configs:
|
||||
q_config["output_tensors"] = None
|
||||
|
||||
return cls(quant_config=config,
|
||||
kv_cache_group=kv_cache_group,
|
||||
kv_cache_config=kv_cache_config,
|
||||
pack_method=pack_method)
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return []
|
||||
|
||||
def _check_scheme_supported(self,
|
||||
min_capability: int,
|
||||
error: bool = True) -> bool:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
|
||||
if capability_tuple is not None:
|
||||
capability = capability_tuple.to_int()
|
||||
supported = capability >= min_capability
|
||||
if error and not supported:
|
||||
raise RuntimeError(
|
||||
"Quantization scheme is not supported for ",
|
||||
f"the current GPU. Min capability: {min_capability}. ",
|
||||
f"Current capability: {capability}.")
|
||||
return supported
|
||||
else:
|
||||
return False
|
||||
|
||||
def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]],
|
||||
input_quant: Optional[Dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
# Confirm weight scheme is supported
|
||||
is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
|
||||
and input_quant.get("dtype") == "fp8_e4m3")
|
||||
is_static_weight = not weight_quant.get("is_dynamic")
|
||||
is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"])
|
||||
|
||||
if not (is_fp8_dtype and is_static_weight
|
||||
and is_per_tensor_or_channel_weight):
|
||||
return False
|
||||
|
||||
# Dynamic quantization is always supported if weights supported.
|
||||
if input_quant.get("is_dynamic"):
|
||||
return True
|
||||
|
||||
# Confirm activation scheme is supported.
|
||||
is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
|
||||
return is_per_tensor_activation
|
||||
|
||||
def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]],
|
||||
input_quant: Optional[Dict[str, Any]]) -> bool:
|
||||
# Confirm weights and input quantized.
|
||||
if weight_quant is None or input_quant is None:
|
||||
return False
|
||||
|
||||
is_int8_dtype = (weight_quant.get("dtype") == "int8"
|
||||
and input_quant.get("dtype") == "int8")
|
||||
|
||||
is_tensor = (weight_quant.get("qscheme")
|
||||
in ["per_tensor", "per_channel"]
|
||||
and input_quant.get("qscheme") == "per_tensor")
|
||||
|
||||
is_static = (not weight_quant.get("is_dynamic")
|
||||
and not input_quant.get("is_dynamic"))
|
||||
|
||||
is_weight_symmetric = (weight_quant.get("symmetric") is True)
|
||||
|
||||
# Both symmetric and asymmetric input quantization supported.
|
||||
# Only symmetric weight quantization supported.
|
||||
return is_int8_dtype and is_tensor and is_weight_symmetric and is_static
|
||||
|
||||
def _find_matched_config(self, layer_name: str,
|
||||
module: torch.nn.Module) -> Dict[str, Any]:
|
||||
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
if proj_name in self.packed_modules_mapping:
|
||||
shard_proj_names = self.packed_modules_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
shard_configs = [
|
||||
self._find_matched_config(shard_name, module)
|
||||
for shard_name in shard_names
|
||||
]
|
||||
if not all(
|
||||
deep_compare(q_config, shard_configs[0])
|
||||
for q_config in shard_configs):
|
||||
raise ValueError(
|
||||
f"Found a different quantization configuration for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
return shard_configs[0]
|
||||
else:
|
||||
layer_quant_config = cast(
|
||||
Dict[str, Any], self.quant_config.get("layer_quant_config"))
|
||||
for name_pattern in layer_quant_config:
|
||||
if fnmatch.fnmatch(layer_name, name_pattern):
|
||||
return layer_quant_config[name_pattern]
|
||||
|
||||
layer_type = cast(str, type(module))
|
||||
layer_type_quant_config = cast(
|
||||
Dict[str, Any],
|
||||
self.quant_config.get("layer_type_quant_config"))
|
||||
if layer_type in layer_type_quant_config:
|
||||
return layer_type_quant_config[layer_type]
|
||||
|
||||
global_quant_config = cast(
|
||||
Dict[str, Any], self.quant_config.get("global_quant_config"))
|
||||
return global_quant_config
|
||||
|
||||
def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme":
|
||||
if config.get("output_tensors") or config.get("bias"):
|
||||
raise NotImplementedError(
|
||||
"Currently, Quark models with output_tensors "
|
||||
"and bias quantized are not supported")
|
||||
weight_config = cast(Dict[str, Any], config.get("weight"))
|
||||
input_config = cast(Dict[str, Any], config.get("input_tensors"))
|
||||
|
||||
if self._is_fp8_w8a8(weight_config, input_config):
|
||||
is_fp8_w8a8_supported = self._check_scheme_supported(
|
||||
QuarkW8A8Fp8.get_min_capability(), error=False)
|
||||
if is_fp8_w8a8_supported:
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
input_static = (input_config is not None and
|
||||
not cast(bool, input_config.get("is_dynamic")))
|
||||
return QuarkW8A8Fp8(qscheme=weight_qscheme,
|
||||
is_static_input_scheme=input_static)
|
||||
elif self._is_static_tensor_w8a8(weight_config, input_config):
|
||||
weight_qscheme = cast(str, weight_config.get("qscheme"))
|
||||
return QuarkW8A8Int8(qscheme=weight_qscheme,
|
||||
is_static_input_scheme=True,
|
||||
input_symmetric=input_config.get("symmetric"))
|
||||
|
||||
raise NotImplementedError("No quark compatible scheme was found. "
|
||||
f"Weight config: {weight_config}, "
|
||||
f"Input config: {input_config}")
|
||||
|
||||
def get_scheme(self, layer: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkScheme":
|
||||
|
||||
layer_quant_config = self._find_matched_config(layer_name, layer)
|
||||
|
||||
# Find the quant_scheme
|
||||
scheme = self._get_scheme_from_config(layer_quant_config)
|
||||
# Raise error if device does not support the scheme
|
||||
# (e.g. fp8 needs ada lovelace)
|
||||
self._check_scheme_supported(scheme.get_min_capability())
|
||||
|
||||
return scheme
|
||||
|
||||
def get_cache_scale(self, name: str) -> Optional[str]:
|
||||
"""
|
||||
Check whether the param name matches the format for k/v cache scales
|
||||
in quark. If this is the case, return its equivalent param name
|
||||
expected by vLLM
|
||||
|
||||
:param name: param name
|
||||
:return: matching param name for KV cache scale in vLLM
|
||||
"""
|
||||
if self.kv_cache_group is None or len(self.kv_cache_group) == 0:
|
||||
return None
|
||||
|
||||
kv_proj_names = [
|
||||
re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group
|
||||
]
|
||||
if name.endswith(".output_scale"):
|
||||
if len(kv_proj_names) == 1 and kv_proj_names[0] in name:
|
||||
kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale"
|
||||
return name.replace(kv_output_scale_name, ".attn.k_scale")
|
||||
|
||||
elif len(kv_proj_names) == 2:
|
||||
for kv_proj_name in kv_proj_names:
|
||||
if kv_proj_name in name and kv_proj_name == "k_proj":
|
||||
return name.replace(".k_proj.output_scale",
|
||||
".attn.k_scale")
|
||||
elif kv_proj_name in name and kv_proj_name == "v_proj":
|
||||
return name.replace(".v_proj.output_scale",
|
||||
".attn.v_scale")
|
||||
|
||||
# If no matches, return None
|
||||
return None
|
||||
|
||||
|
||||
class QuarkLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quantization_config: QuarkConfig):
|
||||
self.quantization_config = quantization_config
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.scheme.process_weights_after_loading(layer)
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: List[int], input_size: int,
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""
|
||||
Use the CompressedTensorsScheme associated with each layer to create
|
||||
the necessary parameters for the layer. See LinearMethodBase for param
|
||||
details
|
||||
"""
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
layer.scheme.create_weights(
|
||||
layer=layer,
|
||||
input_size=input_size,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
output_partition_sizes=output_partition_sizes,
|
||||
output_size=output_size,
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Use the output of create_weights and the CompressedTensorsScheme
|
||||
associated with the layer to apply the forward pass with the
|
||||
layer input. See LinearMethodBase for param details
|
||||
|
||||
"""
|
||||
scheme = layer.scheme
|
||||
if scheme is None:
|
||||
raise ValueError("A scheme must be defined for each layer")
|
||||
return scheme.apply_weights(layer, x, bias=bias)
|
||||
|
||||
|
||||
class QuarkKVCacheMethod(BaseKVCacheMethod):
|
||||
"""
|
||||
Supports loading kv-cache scaling factors from quark checkpoints.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: QuarkConfig):
|
||||
self.validate_kv_cache_config(quant_config.kv_cache_config)
|
||||
super().__init__(quant_config)
|
||||
|
||||
@staticmethod
|
||||
def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]):
|
||||
"""
|
||||
Validator for the kv cache configuration. Useful for controlling the
|
||||
kv cache quantization schemes, that are being supported in vLLM
|
||||
:param kv_cache_config: the quark kv cache scheme
|
||||
"""
|
||||
if kv_cache_config is None:
|
||||
return
|
||||
|
||||
dtype = kv_cache_config.get("dtype")
|
||||
if dtype != "fp8_e4m3":
|
||||
raise NotImplementedError(
|
||||
"Currently supported kv cache quantization is "
|
||||
f"dtype=fp8_e4m3, however received {dtype}")
|
||||
|
||||
qscheme = kv_cache_config.get("qscheme")
|
||||
if qscheme != "per_tensor":
|
||||
raise NotImplementedError(
|
||||
"Only support per-tensor scaling factor "
|
||||
"for quark KV cache. "
|
||||
f"Expected qscheme: per_tensor, found qscheme: {qscheme}")
|
||||
236
vllm/model_executor/layers/quantization/quark/quark_moe.py
Normal file
236
vllm/model_executor/layers/quantization/quark/quark_moe.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported)
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"]
|
||||
|
||||
|
||||
class QuarkMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821
|
||||
module: torch.nn.Module,
|
||||
layer_name: str) -> "QuarkMoEMethod":
|
||||
layer_quant_config = quant_config._find_matched_config(
|
||||
layer_name, module)
|
||||
|
||||
if (layer_quant_config.get("output_tensors")
|
||||
or layer_quant_config.get("bias")):
|
||||
raise NotImplementedError("Currently, Quark models with "
|
||||
"output_tensors and bias "
|
||||
"quantized are not supported")
|
||||
weight_config = layer_quant_config.get("weight")
|
||||
input_config = layer_quant_config.get("input_tensors")
|
||||
|
||||
if quant_config._is_fp8_w8a8(weight_config, input_config):
|
||||
return QuarkW8A8Fp8MoEMethod(weight_config, input_config)
|
||||
else:
|
||||
raise RuntimeError("Unsupported FusedMoe scheme")
|
||||
|
||||
|
||||
class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):
|
||||
|
||||
def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str,
|
||||
Any]):
|
||||
self.weight_quant = weight_config
|
||||
self.input_quant = input_config
|
||||
|
||||
weight_qscheme = self.weight_quant.get("qscheme")
|
||||
input_qscheme = self.input_quant.get("qscheme")
|
||||
if not (weight_qscheme == "per_tensor"
|
||||
and input_qscheme == "per_tensor"):
|
||||
raise ValueError(
|
||||
"For FP8 Fused MoE layers, only per-tensor scales "
|
||||
"for weights and activations are supported. Found "
|
||||
f"{weight_qscheme}, {input_qscheme}") # noqa E501
|
||||
|
||||
self.static_input_scales = not self.input_quant.get("is_dynamic")
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module, num_experts: int,
|
||||
hidden_size: int, intermediate_size_per_partition: int,
|
||||
params_dtype: torch.dtype, **extra_weight_attrs):
|
||||
|
||||
params_dtype = torch.float8_e4m3fn
|
||||
|
||||
# WEIGHTS
|
||||
w13_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
|
||||
w2_weight = torch.nn.Parameter(torch.empty(
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
# They will be combined to a single scale after weight loading.
|
||||
w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
2,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_weight_scale", w13_weight_scale)
|
||||
|
||||
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
|
||||
dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_weight_scale", w2_weight_scale)
|
||||
# Add the quantization method used (per tensor/grouped/channel)
|
||||
# to ensure the weight scales are loaded in properly
|
||||
extra_weight_attrs.update(
|
||||
{"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
|
||||
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
|
||||
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
|
||||
|
||||
# INPUT_SCALES
|
||||
if self.static_input_scales:
|
||||
w13_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w13_input_scale", w13_input_scale)
|
||||
set_weight_attrs(w13_input_scale, extra_weight_attrs)
|
||||
|
||||
w2_input_scale = torch.nn.Parameter(torch.ones(
|
||||
num_experts, dtype=torch.float32),
|
||||
requires_grad=False)
|
||||
layer.register_parameter("w2_input_scale", w2_input_scale)
|
||||
set_weight_attrs(w2_input_scale, extra_weight_attrs)
|
||||
else:
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# Fp8 moe kernels require a single activation scale.
|
||||
# We take the max of all the scales in case they differ.
|
||||
if self.static_input_scales:
|
||||
if (layer.w13_input_scale is None or layer.w2_input_scale is None):
|
||||
raise ValueError(
|
||||
"QuantConfig has static quantization, but found "
|
||||
"activation scales are None.")
|
||||
if (not all_close_1d(layer.w13_input_scale)
|
||||
or not all_close_1d(layer.w2_input_scale)):
|
||||
logger.warning_once(
|
||||
"Found input_scales that are not equal for "
|
||||
"fp8 MoE layer. Using the maximum across experts "
|
||||
"for each layer. ")
|
||||
layer.w13_input_scale = torch.nn.Parameter(
|
||||
layer.w13_input_scale.max(), requires_grad=False)
|
||||
layer.w2_input_scale = torch.nn.Parameter(
|
||||
layer.w2_input_scale.max(), requires_grad=False)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
# Normalize the weights and scales
|
||||
w13_weight, w13_weight_scale, w13_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w13_weight, layer.w13_weight_scale,
|
||||
layer.w13_input_scale)
|
||||
w2_weight, w2_weight_scale, w2_input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
layer.w2_weight, layer.w2_weight_scale,
|
||||
layer.w2_input_scale)
|
||||
# Reset the parameter
|
||||
layer.w13_weight = torch.nn.Parameter(w13_weight,
|
||||
requires_grad=False)
|
||||
layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
|
||||
requires_grad=False)
|
||||
if w13_input_scale is not None:
|
||||
layer.w13_input_scale = torch.nn.Parameter(w13_input_scale,
|
||||
requires_grad=False)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_weight,
|
||||
requires_grad=False)
|
||||
layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
|
||||
requires_grad=False)
|
||||
if w2_input_scale is not None:
|
||||
layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
# Fp8 moe kernel needs single weight scale for w13 per expert.
|
||||
# We take the max then dequant and requant each expert.
|
||||
assert layer.w13_weight_scale is not None
|
||||
shard_size = layer.intermediate_size_per_partition
|
||||
max_w13_scales = layer.w13_weight_scale.max(dim=1).values
|
||||
for expert_id in range(layer.local_num_experts):
|
||||
start = 0
|
||||
for shard_id in range(2):
|
||||
dq_weight = per_tensor_dequantize(
|
||||
layer.w13_weight[expert_id][start:start + shard_size, :],
|
||||
layer.w13_weight_scale[expert_id][shard_id])
|
||||
layer.w13_weight[expert_id][
|
||||
start:start + shard_size, :], _ = ops.scaled_fp8_quant(
|
||||
dq_weight, max_w13_scales[expert_id])
|
||||
start += shard_size
|
||||
|
||||
layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
|
||||
requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
use_fp8_w8a8=True,
|
||||
global_num_experts=global_num_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
expert_map=expert_map,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
a1_scale=layer.w13_input_scale,
|
||||
a2_scale=layer.w2_input_scale)
|
||||
@@ -0,0 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"]
|
||||
@@ -0,0 +1,54 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["QuarkScheme"]
|
||||
|
||||
|
||||
class QuarkScheme(ABC):
|
||||
"""
|
||||
Abstract class used to describe the weight creation and forward pass
|
||||
of different quantization schemes supported by Quark.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
"""
|
||||
Get minimum device capability.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, *args, **kwargs):
|
||||
"""
|
||||
Weight creation for the particular scheme. Inputs to this function
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]):
|
||||
"""
|
||||
Run the forward pass for the particular scheme. This is where
|
||||
scheme-specific dequant/quant steps/kernels should be applied.
|
||||
|
||||
:param layer: torch.nn.Module with the registered weights and
|
||||
other parameters relevant to the particular scheme.
|
||||
:param x: input to the layer
|
||||
:param bias: bias parameter
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module):
|
||||
"""
|
||||
Called after weight loading is complete for any cleanup that
|
||||
needs to occur.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
|
||||
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
|
||||
self.out_dtype = torch.get_default_dtype()
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per tensor
|
||||
if self.qscheme == "per_tensor":
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=max_w_scale,
|
||||
input_scale=layer.input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.qscheme == "per_channel":
|
||||
weight = layer.weight
|
||||
|
||||
if current_platform.is_fp8_fnuz():
|
||||
weight, weight_scale, input_scale = \
|
||||
normalize_e4m3fn_to_e4m3fnuz(
|
||||
weight=weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
input_scale=layer.input_scale)
|
||||
if input_scale is not None:
|
||||
layer.input_scale = Parameter(input_scale,
|
||||
requires_grad=False)
|
||||
else:
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization scheme {self.qscheme}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
layer.input_scale = Parameter(layer.input_scale.max(),
|
||||
requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=torch.float8_e4m3fn),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
# TODO: update create_xxx_parameter functions to return
|
||||
# the newly added parameters
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
|
||||
# min requirement for fp8 kernels
|
||||
weight_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
input_scale[:] = torch.finfo(torch.float32).min
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
def apply_weights(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
|
||||
return self.fp8_linear.apply(input=x,
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
out_dtype=self.out_dtype,
|
||||
input_scale=layer.input_scale,
|
||||
bias=bias)
|
||||
@@ -0,0 +1,107 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, List, Optional, Set
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
|
||||
_kernel_backends_being_used: Set[str] = set()
|
||||
|
||||
def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
|
||||
input_symmetric: Optional[bool]):
|
||||
self.qscheme = qscheme
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# turing and up
|
||||
return 75
|
||||
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
output_partition_sizes: List[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype, weight_loader: Callable,
|
||||
**kwargs):
|
||||
self.logical_widths = output_partition_sizes
|
||||
|
||||
scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
|
||||
is_channelwise=(self.qscheme == "per_channel"),
|
||||
is_static_input_scheme=(self.is_static_input_scheme is True),
|
||||
input_symmetric=(self.input_symmetric is True))
|
||||
|
||||
kernel_type = choose_scaled_mm_linear_kernel(
|
||||
scaled_mm_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: quark stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = BasevLLMParameter(
|
||||
data=torch.empty(1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
102
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
102
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,102 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import re
|
||||
from types import MappingProxyType
|
||||
from typing import Any, Iterable, List, Mapping, Optional
|
||||
|
||||
|
||||
def deep_compare(dict1: Any, dict2: Any) -> bool:
|
||||
if type(dict1) is not type(dict2):
|
||||
return False
|
||||
if isinstance(dict1, dict):
|
||||
if dict1.keys() != dict2.keys():
|
||||
return False
|
||||
return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
|
||||
elif isinstance(dict1, list):
|
||||
return set(dict1) == set(dict2)
|
||||
else:
|
||||
return dict1 == dict2
|
||||
|
||||
|
||||
def should_ignore_layer(
|
||||
layer_name: Optional[str],
|
||||
ignore: Iterable[str],
|
||||
fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
|
||||
) -> bool:
|
||||
if layer_name is None:
|
||||
return False
|
||||
|
||||
# layer_name = model.layers.0.self_attn.qkv_proj
|
||||
# proj_name = qkv_proj
|
||||
proj_name = layer_name.split(".")[-1]
|
||||
|
||||
# Fused layers like gate_up_proj or qkv_proj will not be fused
|
||||
# in the safetensors checkpoint. So, we convert the name
|
||||
# from the fused version to unfused + check to make sure that
|
||||
# each shard of the fused layer has the same scheme.
|
||||
if proj_name in fused_mapping:
|
||||
shard_proj_names = fused_mapping[proj_name]
|
||||
|
||||
# Convert fused_name --> [shard_names]
|
||||
shard_names = [
|
||||
layer_name.replace(proj_name, shard_proj_name)
|
||||
for shard_proj_name in shard_proj_names
|
||||
]
|
||||
|
||||
# Layer should be ignored if shards are ignored.
|
||||
should_ignore_layer = None
|
||||
for shard_name in shard_names:
|
||||
should_ignore_shard = check_equal_or_regex_match(
|
||||
layer_name=shard_name, targets=ignore)
|
||||
|
||||
# If shard_idx=0, set layer ignore to match shard.
|
||||
if should_ignore_layer is None:
|
||||
should_ignore_layer = should_ignore_shard
|
||||
|
||||
# If shard_idx=1+ confirm scheme matches prior shards.
|
||||
elif should_ignore_shard != should_ignore_layer:
|
||||
raise ValueError(f"Found a different quantization schemes for "
|
||||
f"{shard_proj_names} in {layer_name}. vLLM "
|
||||
"requires all to use the same scheme.")
|
||||
|
||||
# Unfused layers like down_proj and o_proj will match
|
||||
# the safetensors checkpoint already.
|
||||
else:
|
||||
should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
|
||||
targets=ignore)
|
||||
|
||||
assert should_ignore_layer is not None
|
||||
return should_ignore_layer
|
||||
|
||||
|
||||
def check_equal_or_regex_match(layer_name: str,
|
||||
targets: Iterable[str]) -> bool:
|
||||
"""
|
||||
Checks whether a layer_name is exactly equal or a regex match for
|
||||
if target starts with 're:' to any target in list.
|
||||
"""
|
||||
for target in targets:
|
||||
if _is_equal_or_regex_match(layer_name, target):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _is_equal_or_regex_match(value: str,
|
||||
target: str,
|
||||
check_contains: bool = False) -> bool:
|
||||
"""
|
||||
Checks whether a value is exactly equal or a regex match for target
|
||||
if target starts with 're:'. If check_contains is set to True,
|
||||
additionally checks if the target string is contained within the value.
|
||||
"""
|
||||
|
||||
if target.startswith("re:"):
|
||||
pattern = target[3:]
|
||||
if re.match(pattern, value):
|
||||
return True
|
||||
elif check_contains:
|
||||
if target.lower() in value.lower():
|
||||
return True
|
||||
elif target == value:
|
||||
return True
|
||||
return False
|
||||
Reference in New Issue
Block a user