first commit
432 vllm/model_executor/layers/quantization/quark/quark.py Normal file
@@ -0,0 +1,432 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import fnmatch
from typing import Any, Optional, cast

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkMoEMethod)
from vllm.model_executor.layers.quantization.quark.schemes import (
    QuarkScheme, QuarkW4A4MXFP4, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.model_executor.layers.quantization.quark.utils import (
    deep_compare, should_ignore_layer)
from vllm.platforms import current_platform

__all__ = ["QuarkLinearMethod"]

logger = init_logger(__name__)


class QuarkConfig(QuantizationConfig):

    def __init__(self,
                 quant_config: dict[str, Any],
                 kv_cache_group: Optional[list[str]] = None,
                 kv_cache_config: Optional[dict[str, Any]] = None,
                 pack_method: str = "reorder"):
        super().__init__()
        if kv_cache_group is None:
            kv_cache_group = []
        self.quant_config = quant_config
        self.kv_cache_group = kv_cache_group
        self.kv_cache_config = kv_cache_config
        self.pack_method = pack_method

    def get_linear_method(self) -> "QuarkLinearMethod":
        return QuarkLinearMethod(self)

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> QuantizationMethods:
        return "quark"

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
        from vllm.attention.layer import Attention  # Avoid circular import

        # Check if the layer is skipped for quantization.
        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
        if should_ignore_layer(prefix,
                               ignore=exclude_layers,
                               fused_mapping=self.packed_modules_mapping):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            layer.scheme = scheme
            return QuarkLinearMethod(self)
        if isinstance(layer, Attention):
            return QuarkKVCacheMethod(self)

        if isinstance(layer, FusedMoE):
            return QuarkMoEMethod.get_moe_method(self,
                                                 module=layer,
                                                 layer_name=prefix)
        return None

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
        export_config = config.get("export")
        if export_config is None:
            raise ValueError("The export key should be included in "
                             "the configuration of a Quark quantized model")
        kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
        pack_method = cast(str, export_config.get("pack_method"))

        # In a Quark-exported model, the kv_cache quantization
        # configuration is stored in layer_quant_config. First check
        # whether kv_cache_group exists, then check whether
        # layer_quant_config has a quantization configuration that
        # matches the kv_cache.
        if len(kv_cache_group) == 0:
            kv_cache_config = None
        else:
            kv_cache_set = set(kv_cache_group)
            layer_quant_config = cast(dict[str, Any],
                                      config.get("layer_quant_config"))
            layer_quant_names = list(layer_quant_config.keys())
            layer_quant_set = set(layer_quant_names)

            if not kv_cache_set.issubset(layer_quant_set):
                raise ValueError("The Quark quantized model has the "
                                 "kv_cache_group parameter setting, "
                                 "but no kv_cache quantization settings "
                                 "were found in the quantization "
                                 "configuration.")

            q_configs = [
                cast(dict[str, Any], layer_quant_config.get(name))
                for name in kv_cache_group
            ]
            if not all(
                    deep_compare(q_config, q_configs[0])
                    for q_config in q_configs):
                raise ValueError(
                    "The quantization method used for the kv_cache should "
                    "be the same, but the quantization methods for the "
                    "kv_cache layers in the config differ.")
            kv_cache_config = q_configs[0].get("output_tensors")
            if kv_cache_config is None:
                raise ValueError(
                    "The kv_cache quantization configuration is empty.")

            # Since we have already set the kv_cache quantization
            # configuration, remove the quantization configuration for the
            # output_tensors corresponding to the kv_cache layers.
            for q_config in q_configs:
                q_config["output_tensors"] = None

            # In case the q_proj output is also quantized, remove the
            # configuration to keep qkv consistency.
            q_proj_q_config = cast(dict[str, Any],
                                   layer_quant_config.get("*q_proj"))
            if q_proj_q_config is not None:
                q_proj_q_config["output_tensors"] = None

        return cls(quant_config=config,
                   kv_cache_group=kv_cache_group,
                   kv_cache_config=kv_cache_config,
                   pack_method=pack_method)
    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []

    def _check_scheme_supported(self,
                                min_capability: int,
                                error: bool = True) -> bool:
        capability_tuple = current_platform.get_device_capability()

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for "
                    f"the current GPU. Min capability: {min_capability}. "
                    f"Current capability: {capability}.")
            return supported
        else:
            return False

    def _is_fp8_w8a8(self, weight_quant: Optional[dict[str, Any]],
                     input_quant: Optional[dict[str, Any]]) -> bool:
        # Confirm weights and input are quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm the weight scheme is supported.
        is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3"
                        and input_quant.get("dtype") == "fp8_e4m3")
        is_static_weight = not weight_quant.get("is_dynamic")
        is_per_tensor_or_channel_weight = (weight_quant.get("qscheme")
                                           in ["per_tensor", "per_channel"])

        if not (is_fp8_dtype and is_static_weight
                and is_per_tensor_or_channel_weight):
            return False

        # Dynamic quantization is always supported if weights are supported.
        if input_quant.get("is_dynamic"):
            return True

        # Confirm the activation scheme is supported.
        is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor")
        return is_per_tensor_activation

    def _is_static_tensor_w8a8(self, weight_quant: Optional[dict[str, Any]],
                               input_quant: Optional[dict[str, Any]]) -> bool:
        # Confirm weights and input are quantized.
        if weight_quant is None or input_quant is None:
            return False

        is_int8_dtype = (weight_quant.get("dtype") == "int8"
                         and input_quant.get("dtype") == "int8")

        is_tensor = (weight_quant.get("qscheme")
                     in ["per_tensor", "per_channel"]
                     and input_quant.get("qscheme") == "per_tensor")

        is_static = (not weight_quant.get("is_dynamic")
                     and not input_quant.get("is_dynamic"))

        is_weight_symmetric = (weight_quant.get("symmetric") is True)

        # Both symmetric and asymmetric input quantization are supported.
        # Only symmetric weight quantization is supported.
        return is_int8_dtype and is_tensor and is_weight_symmetric and is_static

    def _is_mx_fp4(self, weight_quant: Optional[dict[str, Any]],
                   input_quant: Optional[dict[str, Any]]) -> bool:
        # Confirm weights and input are quantized.
        if weight_quant is None or input_quant is None:
            logger.debug("Quark model is not in MX-FP4 format: "
                         "weight_quant or input_quant not set")
            return False

        # Input and weight dtypes need to be fp4.
        if weight_quant.get("dtype") != "fp4" or input_quant.get(
                "dtype") != "fp4":
            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
            return False

        # Input and weight qschemes need to be per group.
        if weight_quant.get("qscheme") != "per_group" or input_quant.get(
                "qscheme") != "per_group":
            logger.debug("Quark model is not in MX-FP4 format: not per_group")
            return False

        # Input and weight group size needs to be 32.
        if weight_quant.get("group_size") != 32 or input_quant.get(
                "group_size") != 32:
            logger.debug(
                "Quark model is not in MX-FP4 format: not group_size=32")
            return False

        # Activations need to use dynamic quantization.
        if input_quant.get("is_dynamic") is False:
            logger.debug(
                "Quark model is not in MX-FP4 format: activations not dynamic")
            return False

        # Activation and weight scales need to be in e8m0 format.
        if weight_quant.get("scale_format") != "e8m0" or input_quant.get(
                "scale_format") != "e8m0":
            logger.debug(
                "Quark model is not in MX-FP4 format: scale_format not e8m0")
            return False

        return True

    def _find_matched_config(self, layer_name: str,
                             module: torch.nn.Module) -> dict[str, Any]:

        proj_name = layer_name.split(".")[-1]
        if proj_name in self.packed_modules_mapping:
            shard_proj_names = self.packed_modules_mapping[proj_name]

            # Convert fused_name --> [shard_names]
            shard_names = [
                layer_name.replace(proj_name, shard_proj_name)
                for shard_proj_name in shard_proj_names
            ]
            shard_configs = [
                self._find_matched_config(shard_name, module)
                for shard_name in shard_names
            ]
            if not all(
                    deep_compare(q_config, shard_configs[0])
                    for q_config in shard_configs):
                raise ValueError(
                    f"Found a different quantization configuration for "
                    f"{shard_proj_names} in {layer_name}. vLLM "
                    "requires all to use the same scheme.")
            return shard_configs[0]
        else:
            layer_quant_config = cast(
                dict[str, Any], self.quant_config.get("layer_quant_config"))
            for name_pattern in layer_quant_config:
                if fnmatch.fnmatch(layer_name, name_pattern):
                    return layer_quant_config[name_pattern]

            layer_type = cast(str, type(module))
            layer_type_quant_config = cast(
                dict[str, Any],
                self.quant_config.get("layer_type_quant_config"))
            if layer_type in layer_type_quant_config:
                return layer_type_quant_config[layer_type]

            global_quant_config = cast(
                dict[str, Any], self.quant_config.get("global_quant_config"))
            return global_quant_config
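
    # Matching above is fnmatch-based. For example (hypothetical names), the
    # pattern "*q_proj" in layer_quant_config matches
    # "model.layers.7.self_attn.q_proj", while a fused "qkv_proj" module is
    # first expanded through packed_modules_mapping into its q/k/v shard
    # names, and each shard's resolved config must deep-compare equal.
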
    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
        if config.get("output_tensors") or config.get("bias"):
            raise NotImplementedError(
                "Currently, Quark models with quantized output_tensors "
                "and bias are not supported")
        weight_config = cast(dict[str, Any], config.get("weight"))
        input_config = cast(dict[str, Any], config.get("input_tensors"))

        if self._is_fp8_w8a8(weight_config, input_config):
            is_fp8_w8a8_supported = self._check_scheme_supported(
                QuarkW8A8Fp8.get_min_capability(), error=False)
            if is_fp8_w8a8_supported:
                return QuarkW8A8Fp8(weight_config, input_config)
        elif self._is_static_tensor_w8a8(weight_config, input_config):
            weight_qscheme = cast(str, weight_config.get("qscheme"))
            return QuarkW8A8Int8(qscheme=weight_qscheme,
                                 is_static_input_scheme=True,
                                 input_symmetric=input_config.get("symmetric"))
        elif self._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFP4(weight_config, input_config)

        raise NotImplementedError("No Quark-compatible scheme was found. "
                                  f"Weight config: {weight_config}, "
                                  f"Input config: {input_config}")

    def get_scheme(self, layer: torch.nn.Module,
                   layer_name: str) -> "QuarkScheme":

        layer_quant_config = self._find_matched_config(layer_name, layer)

        # Find the quant_scheme
        scheme = self._get_scheme_from_config(layer_quant_config)
        # Raise an error if the device does not support the scheme
        # (e.g. fp8 needs Ada Lovelace).
        self._check_scheme_supported(scheme.get_min_capability())

        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in Quark. If this is the case, return its equivalent param name
        expected by vLLM.

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        if name.endswith(".output_scale") and ".q_proj" in name:
            return name.replace(".q_proj.output_scale", ".attn.q_scale")
        if name.endswith("self_attn.prob_output_scale"):
            return name.replace(".prob_output_scale", ".attn.prob_scale")

        # If there is no match, return None.
        return None

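
# Illustration of get_cache_scale (hypothetical parameter names):
#   "model.layers.0.self_attn.k_proj.output_scale"
#       -> "model.layers.0.self_attn.attn.k_scale"
#   "model.layers.0.self_attn.prob_output_scale"
#       -> "model.layers.0.self_attn.attn.prob_scale"
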
class QuarkLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: QuarkConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: list[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        """
        Use the QuarkScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for
        param details.
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader)

    def apply(self,
              layer: torch.nn.Module,
              x: torch.Tensor,
              bias: Optional[torch.Tensor] = None):
        """
        Use the output of create_weights and the QuarkScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details.
        """
        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")

        return scheme.apply_weights(layer, x, bias=bias)


class QuarkKVCacheMethod(BaseKVCacheMethod):
    """
    Supports loading kv-cache scaling factors from Quark checkpoints.
    """

    def __init__(self, quant_config: QuarkConfig):
        self.validate_kv_cache_config(quant_config.kv_cache_config)
        super().__init__(quant_config)

    @staticmethod
    def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
        """
        Validator for the kv cache configuration. Useful for controlling
        the kv cache quantization schemes that are supported in vLLM.

        :param kv_cache_config: the Quark kv cache scheme
        """
        if kv_cache_config is None:
            return

        dtype = kv_cache_config.get("dtype")
        if dtype != "fp8_e4m3":
            raise NotImplementedError(
                "Currently supported kv cache quantization is "
                f"dtype=fp8_e4m3, however received {dtype}")

        qscheme = kv_cache_config.get("qscheme")
        if qscheme != "per_tensor":
            raise NotImplementedError(
                "Only per-tensor scaling factors are supported "
                "for the Quark KV cache. "
                f"Expected qscheme: per_tensor, found qscheme: {qscheme}")
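A usage sketch of the entry point above. The dict is a hypothetical, minimal Quark config of the shape from_config parses; real checkpoints carry much richer per-layer entries.

from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig

cfg = QuarkConfig.from_config({
    "export": {"kv_cache_group": [], "pack_method": "reorder"},
    "global_quant_config": {
        "weight": {"dtype": "fp8_e4m3", "qscheme": "per_tensor",
                   "is_dynamic": False},
        "input_tensors": None,
        "output_tensors": None,
        "bias": None,
    },
    "layer_quant_config": {},
    "exclude": [],
})
assert cfg.get_name() == "quark"
# get_quant_method() then dispatches per layer: LinearBase ->
# QuarkLinearMethod, Attention -> QuarkKVCacheMethod, FusedMoE ->
# a QuarkMoEMethod subclass, excluded layers -> UnquantizedLinearMethod.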
561 vllm/model_executor/layers/quantization/quark/quark_moe.py Normal file
@@ -0,0 +1,561 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional, Union

import torch

import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                  FusedMoEMethodBase,
                                                  FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.config import (
    FusedMoEQuantConfig, fp8_w8a8_moe_quant_config,
    mxfp4_w4a4_moe_quant_config)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
    is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
    prepare_moe_fp8_layer_for_marlin)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types

logger = init_logger(__name__)

__all__ = [
    "QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod", "QuarkW4A4MXFp4MoEMethod"
]


class QuarkMoEMethod(FusedMoEMethodBase):

    def __init__(self, moe: FusedMoEConfig):
        super().__init__(moe)

    @staticmethod
    def get_moe_method(
            quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
            module: torch.nn.Module,
            layer_name: str) -> "QuarkMoEMethod":
        layer_quant_config = quant_config._find_matched_config(
            layer_name, module)

        if (layer_quant_config.get("output_tensors")
                or layer_quant_config.get("bias")):
            raise NotImplementedError("Currently, Quark models with "
                                      "quantized output_tensors and bias "
                                      "are not supported")
        weight_config = layer_quant_config.get("weight")
        input_config = layer_quant_config.get("input_tensors")

        if quant_config._is_fp8_w8a8(weight_config, input_config):
            return QuarkW8A8Fp8MoEMethod(weight_config, input_config,
                                         module.moe_config)
        elif quant_config._is_mx_fp4(weight_config, input_config):
            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config,
                                           module.moe_config)
        else:
            raise RuntimeError("Unsupported FusedMoE scheme")

class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod):

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        self.weight_qscheme = self.weight_quant.get("qscheme")
        self.input_qscheme = self.input_quant.get("qscheme")
        per_tensor = (self.weight_qscheme == "per_tensor"
                      and self.input_qscheme == "per_tensor")
        per_channel = (self.weight_qscheme == "per_channel"
                       and self.input_qscheme == "per_channel")
        self.act_quant_group_shape = GroupShape.PER_TOKEN \
            if per_channel else GroupShape.PER_TENSOR
        if not (per_tensor or per_channel):
            raise ValueError(
                "For FP8 Fused MoE layers, only per-tensor and per-channel "
                "scales for weights and activations are supported. Found "
                f"{self.weight_qscheme}, {self.input_qscheme}")  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")
        if self.static_input_scales and per_channel:
            raise ValueError(
                "For FP8 Fused MoE layers, we require either per-tensor or "
                "channelwise, dynamic per-token quantization.")

        # For GPUs that lack FP8 hardware support, we can leverage the Marlin
        # kernel for fast weight-only FP8 quantization.
        self.use_marlin = (not current_platform.has_device_capability(89)
                           or envs.VLLM_TEST_FORCE_FP8_MARLIN)
        # Disable marlin for rocm
        if current_platform.is_rocm():
            self.use_marlin = False

        self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        layer.intermediate_size_per_partition = intermediate_size_per_partition
        layer.hidden_size = hidden_size
        layer.num_experts = num_experts
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None
        params_dtype = torch.float8_e4m3fn

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        if self.weight_qscheme == "per_tensor":
            # Allocate 2 scales for w1 and w3 respectively.
            # They are combined to a single scale after weight loading.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, 2, dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        elif self.weight_qscheme == "per_channel":
            # Quark's scale is 1-dimensional.
            w13_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                dtype=torch.float32),
                                                  requires_grad=False)
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            w2_weight_scale = torch.nn.Parameter(torch.ones(
                num_experts, hidden_size, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)
            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # INPUT_SCALES
        if self.static_input_scales:
            w13_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                 requires_grad=False)
            layer.register_parameter("w13_input_scale", w13_input_scale)
            set_weight_attrs(w13_input_scale, extra_weight_attrs)

            w2_input_scale = torch.nn.Parameter(torch.ones(
                num_experts, dtype=torch.float32),
                                                requires_grad=False)
            layer.register_parameter("w2_input_scale", w2_input_scale)
            set_weight_attrs(w2_input_scale, extra_weight_attrs)
        else:
            layer.w13_input_scale = None
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Fp8 moe kernels require a single activation scale.
        # We take the max of all the scales in case they differ.
        if self.static_input_scales:
            if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                raise ValueError(
                    "QuantConfig has static quantization, but found "
                    "activation scales are None.")
            if (not all_close_1d(layer.w13_input_scale)
                    or not all_close_1d(layer.w2_input_scale)):
                logger.warning_once(
                    "Found input_scales that are not equal for "
                    "fp8 MoE layer. Using the maximum across experts "
                    "for each layer.")
            layer.w13_input_scale = torch.nn.Parameter(
                layer.w13_input_scale.max(), requires_grad=False)
            layer.w2_input_scale = torch.nn.Parameter(
                layer.w2_input_scale.max(), requires_grad=False)

        if current_platform.is_fp8_fnuz():
            # Normalize the weights and scales
            w13_weight, w13_weight_scale, w13_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w13_weight, layer.w13_weight_scale,
                    layer.w13_input_scale)
            w2_weight, w2_weight_scale, w2_input_scale = \
                normalize_e4m3fn_to_e4m3fnuz(
                    layer.w2_weight, layer.w2_weight_scale,
                    layer.w2_input_scale)
            # Reset the parameters
            layer.w13_weight = torch.nn.Parameter(w13_weight,
                                                  requires_grad=False)
            layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale,
                                                        requires_grad=False)
            if w13_input_scale is not None:
                layer.w13_input_scale = torch.nn.Parameter(
                    w13_input_scale, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight,
                                                 requires_grad=False)
            layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale,
                                                       requires_grad=False)
            if w2_input_scale is not None:
                layer.w2_input_scale = torch.nn.Parameter(
                    w2_input_scale, requires_grad=False)

        # For the per-tensor case, the fp8 MoE kernel needs a single weight
        # scale for w13 per expert. Take the max, then dequantize and
        # requantize each expert's shards.
        if self.weight_qscheme == "per_tensor":
            assert layer.w13_weight_scale is not None
            shard_size = layer.intermediate_size_per_partition
            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
            for expert_id in range(layer.local_num_experts):
                start = 0
                for shard_id in range(2):
                    dq_weight = per_tensor_dequantize(
                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id])
                    layer.w13_weight[expert_id][
                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
                            dq_weight, max_w13_scales[expert_id])
                    start += shard_size

            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
                                                        requires_grad=False)
        # Quark's scale is 1-dimensional.
        elif self.weight_qscheme == "per_channel":
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                w13_weight_scale = layer.w13_weight_scale.unsqueeze(-1)
                layer.w13_weight_scale = torch.nn.Parameter(
                    w13_weight_scale, requires_grad=False)
                w2_weight_scale = layer.w2_weight_scale.unsqueeze(-1)
                layer.w2_weight_scale = torch.nn.Parameter(
                    w2_weight_scale, requires_grad=False)
        # Property to determine if AITER is used
        if self.rocm_aiter_moe_enabled:
            from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (  # noqa E501
                rocm_aiter_fused_experts, shuffle_weights)

            # Reshaping weights is required for the aiter moe kernel.
            shuffled_w13, shuffled_w2 = shuffle_weights(
                layer.w13_weight.data, layer.w2_weight.data)

            layer.w13_weight = torch.nn.Parameter(shuffled_w13,
                                                  requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(shuffled_w2,
                                                 requires_grad=False)

            self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts
        elif self.use_marlin:
            prepare_moe_fp8_layer_for_marlin(layer, False)
            # Activations are not quantized for marlin.
            del layer.w13_input_scale
            del layer.w2_input_scale
            self.fused_experts_func = None
        else:
            from vllm.model_executor.layers.fused_moe import fused_experts
            self.fused_experts_func = fused_experts

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        return fp8_w8a8_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            per_act_token_quant=self.weight_qscheme == "per_channel",
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet.")

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            return self.rocm_aiter_fused_experts_func(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                quant_config=self.moe_quant_config,
                expert_map=expert_map)
        if self.use_marlin:
            assert activation == "silu", (
                f"{activation} not supported for Marlin MoE.")
            return torch.ops.vllm.fused_marlin_moe(
                x,
                layer.w13_weight,
                layer.w2_weight,
                None,
                None,
                layer.w13_weight_scale,
                layer.w2_weight_scale,
                router_logits,
                topk_weights,
                topk_ids,
                quant_type_id=scalar_types.float8_e4m3fn.id,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map)

        assert self.fused_experts_func is not None

        return self.fused_experts_func(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            global_num_experts=global_num_experts,
            expert_map=expert_map,
            quant_config=self.moe_quant_config)

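
# Worked illustration of the per-tensor w13 requantization above
# (hypothetical scales): given per-shard scales s = [s1, s3] for an expert,
# each shard is dequantized with its own scale and requantized with
# max(s1, s3), so the kernel can consume a single scale per expert:
#     dq = w13[e, shard].float() * s[shard]
#     w13[e, shard], _ = ops.scaled_fp8_quant(dq, s.max())
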
class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):

    def __init__(
        self,
        weight_config: dict[str, Any],
        input_config: dict[str, Any],
        moe: FusedMoEConfig,
    ):
        super().__init__(moe)
        self.weight_quant = weight_config
        self.input_quant = input_config

        weight_qscheme = self.weight_quant.get("qscheme")
        input_qscheme = self.input_quant.get("qscheme")
        if not (weight_qscheme == "per_group"
                and input_qscheme == "per_group"):
            raise ValueError(
                "For MX(FP4) Fused MoE layers, only per-group scales "
                "for weights and activations are supported. Found "
                f"{weight_qscheme}, {input_qscheme}")  # noqa E501

        self.static_input_scales = not self.input_quant.get("is_dynamic")

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkW4A4MXFp4MoEMethod with static input scales is "
                "currently not implemented. Please open an issue.")

        if not current_platform.supports_mx():
            self.emulate = True
            logger.warning_once(
                "The current platform does not support native MXFP4 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")
        else:
            self.emulate = True
            logger.warning_once(
                "The current platform supports native MXFP4 "
                "computation, but kernels are not yet integrated in vLLM. "
                "Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision.")

    def create_weights(self, layer: torch.nn.Module, num_experts: int,
                       hidden_size: int, intermediate_size_per_partition: int,
                       params_dtype: torch.dtype, **extra_weight_attrs):

        # Add the quantization method used (per tensor/grouped/channel)
        # to ensure the weight scales are loaded in properly.
        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value})

        params_dtype = torch.uint8

        # WEIGHTS
        w13_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            hidden_size // 2,
            dtype=params_dtype),
                                        requires_grad=False)
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        w2_weight = torch.nn.Parameter(torch.empty(
            num_experts,
            hidden_size,
            intermediate_size_per_partition // 2,
            dtype=params_dtype),
                                       requires_grad=False)
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        # WEIGHT_SCALES
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)

    def get_fused_moe_quant_config(
            self, layer: torch.nn.Module) -> Optional[FusedMoEQuantConfig]:
        return mxfp4_w4a4_moe_quant_config(
            w1_scale=layer.w13_weight_scale,
            w2_scale=layer.w2_weight_scale,
            a1_scale=None,
            a2_scale=None,
            block_shape=None,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        routed_scaling_factor: float = 1.0,
        e_score_correction_bias: Optional[torch.Tensor] = None,
        apply_router_weight_on_input: bool = False,
        activation: str = "silu",
        enable_eplb: bool = False,
        expert_load_view: Optional[torch.Tensor] = None,
        logical_to_physical_map: Optional[torch.Tensor] = None,
        logical_replica_count: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        assert self.fused_experts is None

        if enable_eplb:
            raise NotImplementedError(
                "EPLB not supported for `QuarkW4A4MXFp4MoEMethod` yet.")

        from vllm.model_executor.layers.fused_moe import fused_experts

        topk_weights, topk_ids, _ = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            routed_scaling_factor=routed_scaling_factor,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        out = fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            inplace=True,
            activation=activation,
            global_num_experts=global_num_experts,
            apply_router_weight_on_input=apply_router_weight_on_input,
            expert_map=expert_map,
            quant_config=self.moe_quant_config,
        )
        return out
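For orientation, the MXFP4 MoE parameters above pack two 4-bit values per uint8 and keep one e8m0 scale per 32-element block (OCP_MX_BLOCK_SIZE). A quick sketch of the resulting storage shapes, with hypothetical sizes:

num_experts, hidden, inter = 8, 4096, 14336        # hypothetical sizes
w13_packed = (num_experts, 2 * inter, hidden // 2)  # uint8, two fp4 codes each
w13_scales = (num_experts, 2 * inter, hidden // 32) # e8m0, one per 32-block
w2_packed = (num_experts, hidden, inter // 2)
w2_scales = (num_experts, hidden, inter // 32)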
@@ -0,0 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .quark_scheme import QuarkScheme
from .quark_w4a4_mxfp4 import QuarkW4A4MXFP4
from .quark_w8a8_fp8 import QuarkW8A8Fp8
from .quark_w8a8_int8 import QuarkW8A8Int8

__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkW4A4MXFP4"]
@@ -0,0 +1,55 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from abc import ABC, abstractmethod
from typing import Optional

import torch

__all__ = ["QuarkScheme"]


class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get the minimum device capability.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function
        are the target layer and its parameter metadata (partition sizes,
        dtype, weight loader).
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter
        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
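To make the contract concrete, here is a minimal, hypothetical subclass of the QuarkScheme ABC above that performs no real quantization; the create_weights argument names mirror those the concrete schemes in this commit accept:

import torch


class IdentityScheme(QuarkScheme):
    """Toy scheme: full-precision weights, no quantization."""

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def create_weights(self, layer, output_partition_sizes,
                       input_size_per_partition, params_dtype,
                       weight_loader, **kwargs):
        # Allocate the fused weight; real schemes use vLLM parameter classes.
        layer.weight = torch.nn.Parameter(torch.empty(
            sum(output_partition_sizes),
            input_size_per_partition,
            dtype=params_dtype),
                                          requires_grad=False)

    def process_weights_after_loading(self, layer):
        pass  # nothing to repack or fold

    def apply_weights(self, layer, x, bias=None):
        return torch.nn.functional.linear(x, layer.weight, bias)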
@@ -0,0 +1,239 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from functools import cache
from typing import Any, Callable, Optional

import torch
import torch.nn.functional as F

from vllm import envs
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
    OCP_MX_BLOCK_SIZE, dequant_mxfp4, quant_dequant_mxfp4)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                           PackedvLLMParameter)
from vllm.platforms import current_platform


@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
    return current_platform.is_rocm() \
        and envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM \
        and envs.VLLM_ROCM_USE_AITER


try:
    from aiter.ops.shuffle import shuffle_weight
    from aiter.ops.triton.gemm_afp4wfp4 import gemm_afp4wfp4
    from aiter.ops.triton.quant import dynamic_mxfp4_quant

    from vllm.utils import direct_register_custom_op
    if is_rocm_aiter_fp4_asm_gemm_enabled():
        from aiter import gemm_a4w4, per_1x32_f4_quant_hip

    def gemm_with_dynamic_quant(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        rocm_use_aiter_fp4_asm_gemm: bool = False,
        out_dtype: Optional[torch.dtype] = torch.bfloat16,
        x_scales: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        M = x.shape[0]
        if rocm_use_aiter_fp4_asm_gemm:
            if x_scales is None:
                # use hip quant kernel for performance
                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
            else:
                x_q = x
                x_s = x_scales

            # 32 alignment is enough for dim0 padding of output for
            # gemm_a4w4 kernel
            y = torch.empty((M + 31) // 32 * 32,
                            weight.shape[0],
                            device=x_q.device,
                            dtype=out_dtype)

            gemm_a4w4(x_q,
                      weight,
                      x_s,
                      weight_scale.view(x_s.dtype),
                      y,
                      bpreshuffle=True)
            return y[:M]
        else:
            if x_scales is None:
                x_q, x_s = dynamic_mxfp4_quant(x)
            else:
                x_q = x
                x_s = x_scales
            y = torch.empty(x_q.shape[0],
                            weight.shape[0],
                            device=x_q.device,
                            dtype=out_dtype)

            gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
            return y

    def gemm_with_dynamic_quant_fake(
        x: torch.Tensor,
        weight: torch.Tensor,
        weight_scale: torch.Tensor,
        rocm_use_aiter_fp4_asm_gemm: bool = False,
        out_dtype: Optional[torch.dtype] = torch.bfloat16,
        x_scales: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # The fake impl only needs to produce a correctly-shaped meta tensor;
        # its signature is kept in the same order as the real op.
        return torch.empty((*x.shape[:-1], weight.shape[0]),
                           dtype=out_dtype,
                           device=x.device)

    direct_register_custom_op(
        op_name="gemm_with_dynamic_quant",
        op_func=gemm_with_dynamic_quant,
        mutates_args=[],
        fake_impl=gemm_with_dynamic_quant_fake,
        dispatch_key=current_platform.dispatch_key,
    )

except ImportError:
    dynamic_mxfp4_quant = gemm_afp4wfp4 = None

__all__ = ["QuarkW4A4MXFP4"]

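
# MX scales use the e8m0 format: a uint8 storing only a biased exponent,
# decoded as 2.0 ** (int(u) - 127), e.g. u=127 -> 1.0 and u=132 -> 32.0.
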
class QuarkW4A4MXFP4(QuarkScheme):

    def __init__(self, weight_quant_spec: dict[str, Any],
                 input_quant_spec: dict[str, Any]):
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec
        self.emulate = not current_platform.supports_mx()
        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()
        if not self.emulate and (dynamic_mxfp4_quant is None
                                 or gemm_afp4wfp4 is None):
            # Currently need these kernels if not emulating
            raise NotImplementedError(
                f"{self.__class__.__name__} requires AITER to be installed "
                "for non-emulation mode! Please refer to "
                "https://github.com/ROCm/aiter for installation details.")

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.weight = torch.nn.Parameter(layer.weight.data,
                                          requires_grad=False)

        if self.emulate:
            layer.weight_scale = torch.nn.Parameter(layer.weight_scale.data,
                                                    requires_grad=False)
            try:
                from quark.torch.export.nn.modules import realquantizer
                from quark.torch.quantization.config.config import (
                    QuantizationSpec)
            except ImportError as err:
                raise ImportError(
                    "The package `amd-quark` is required to use AMD Quark "
                    "MX-FP4 models. Please install it with `pip install "
                    "amd-quark`.") from err

            weight_quant_spec = QuantizationSpec.from_dict(
                self.weight_quant_spec)

            weight_quantizer = realquantizer.get_real_quantizer(
                qspec=weight_quant_spec,
                quantizer=None,
                real_quantized=True,
                reorder=False,
                float_dtype=self.out_dtype,
                scale_shape=layer.weight_scale.shape,
                zero_point_shape=None,
            )
            weight_quantizer.scale.data = layer.weight_scale.data

            layer.weight = torch.nn.Parameter(
                weight_quantizer(layer.weight.data).to(self.out_dtype),
                requires_grad=False,
            )
            layer.weight_scale = None

            # This call is necessary to release the scales memory.
            torch.cuda.empty_cache()
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
                weight_scale_shuffle = weight_scale_shuffle.view(
                    sm // 32, 2, 16, sn // 8, 2, 4, 1)
                weight_scale_shuffle = weight_scale_shuffle.permute(
                    0, 3, 5, 2, 4, 1, 6).contiguous()
                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
                layer.weight_scale = torch.nn.Parameter(weight_scale_shuffle,
                                                        requires_grad=False)

                # shuffle weight
                weight_shuffle = layer.weight.data
                weight_shuffle = shuffle_weight(weight_shuffle,
                                                layout=(16, 16))
                layer.weight = torch.nn.Parameter(weight_shuffle,
                                                  requires_grad=False)
            else:
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data.T.contiguous(),
                    requires_grad=False)

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=2,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        if self.emulate:
            dq_w = dequant_mxfp4(layer.weight, layer.weight_scale, x.dtype)
            x = quant_dequant_mxfp4(x)
            return F.linear(x, dq_w, bias)
        else:
            return torch.ops.vllm.gemm_with_dynamic_quant(
                x, layer.weight, layer.weight_scale,
                self.rocm_use_aiter_fp4_asm_gemm, self.out_dtype)
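To make the packed layout concrete, a small sketch of unpacking two FP4 (E2M1) codes from one uint8. The code table follows the OCP MX spec; the low-nibble-first ordering is an assumption for illustration, not necessarily what the kernels above use:

import torch

# E2M1 value table indexed by 4-bit code (OCP MX spec).
FP4_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def unpack_fp4(packed: torch.Tensor) -> torch.Tensor:
    lo = packed & 0xF          # first element (assumed low nibble)
    hi = (packed >> 4) & 0xF   # second element
    vals = torch.stack([FP4_E2M1[lo.long()], FP4_E2M1[hi.long()]], dim=-1)
    return vals.flatten(-2)    # interleave back to original element order

codes = torch.tensor([0x21, 0x88], dtype=torch.uint8)
print(unpack_fp4(codes))       # tensor([0.5000, 1.0000, -0.0000, -0.0000])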
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Any, Callable, Optional, cast

import torch
from torch.nn import Parameter

from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    Fp8LinearOp, normalize_e4m3fn_to_e4m3fnuz, requantize_with_max_scale)
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)
from vllm.platforms import current_platform

__all__ = ["QuarkW8A8Fp8"]


class QuarkW8A8Fp8(QuarkScheme):

    def __init__(self, weight_config: dict[str, Any],
                 input_config: Optional[dict[str, Any]]):
        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
        self.is_static_input_scheme: bool = False
        self.input_qscheme: Optional[str] = None
        if input_config is not None:
            self.is_static_input_scheme = not cast(
                bool, input_config.get("is_dynamic"))
            self.input_qscheme = cast(str, input_config.get("qscheme"))

        per_token = (not self.is_static_input_scheme
                     and self.input_qscheme == "per_channel")
        self.act_quant_group_shape = GroupShape.PER_TOKEN \
            if per_token else GroupShape.PER_TENSOR
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.is_static_input_scheme,
            act_quant_group_shape=self.act_quant_group_shape)
        self.out_dtype = torch.get_default_dtype()

    @classmethod
    def get_min_capability(cls) -> int:
        # Ada Lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        # If per-tensor, when we have a fused module (e.g. QKV) with per-
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor.
        if self.weight_qscheme == "per_tensor":
            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)
            else:
                max_w_scale = layer.weight_scale
                weight = layer.weight

            max_w_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=max_w_scale,
                logical_widths=layer.logical_widths,
            )

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, the scales are already lined up, so just transpose.
        elif self.weight_qscheme == "per_channel":
            weight = layer.weight

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, 'input_scale', None)
                weight, weight_scale, input_scale = \
                    normalize_e4m3fn_to_e4m3fnuz(
                        weight=weight,
                        weight_scale=layer.weight_scale,
                        input_scale=input_scale)
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale,
                                                  requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data
                if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                    weight_scale = weight_scale.view(-1, 1)
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(
                f"Unknown quantization scheme {self.weight_qscheme}")

        # INPUT SCALE
        if self.is_static_input_scheme:
            layer.input_scale = Parameter(layer.input_scale.max(),
                                          requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(data=torch.empty(
            output_size_per_partition,
            input_size_per_partition,
            dtype=torch.float8_e4m3fn),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.weight_qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
            assert self.weight_qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                  weight_loader=weight_loader)
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(self,
                      layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:

        return self.fp8_linear.apply(input=x,
                                     weight=layer.weight,
                                     weight_scale=layer.weight_scale,
                                     out_dtype=self.out_dtype,
                                     input_scale=layer.input_scale,
                                     bias=bias)
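The per-tensor branch above collapses the N per-partition scales of a fused module into one. A standalone sketch of that requantization with hypothetical shapes and scales; the real code uses requantize_with_max_scale and true fp8 storage, which are omitted here:

import torch

logical_widths = [128, 128, 128]              # hypothetical fused QKV partitions
scales = torch.tensor([0.010, 0.020, 0.015])  # one fp8 scale per partition
w = torch.randn(sum(logical_widths), 256)     # stands in for the fp8 weight

max_scale = scales.max()
start = 0
for width, s in zip(logical_widths, scales):
    dq = w[start:start + width].float() * s   # dequantize with own scale
    w[start:start + width] = dq / max_scale   # requantize with the max
    start += width
# The kernel now needs only `max_scale` for the whole fused weight.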
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Callable, Optional

import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
    ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel)
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
from vllm.model_executor.parameter import (BasevLLMParameter,
                                           ChannelQuantScaleParameter,
                                           ModelWeightParameter,
                                           PerTensorScaleParameter)

logger = init_logger(__name__)


class QuarkW8A8Int8(QuarkScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool],
                 input_symmetric: Optional[bool]):
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # Turing and up
        return 75

    def create_weights(self, layer: torch.nn.Module,
                       output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.qscheme == "per_channel"),
            is_static_input_scheme=(self.is_static_input_scheme is True),
            input_symmetric=(self.input_symmetric is True))

        kernel_type = choose_scaled_mm_linear_kernel(
            scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)
# WEIGHT
|
||||
weight = ModelWeightParameter(data=torch.empty(
|
||||
sum(output_partition_sizes),
|
||||
input_size_per_partition,
|
||||
dtype=torch.int8),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.qscheme == "per_channel":
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
ChannelQuantZPParameter = ChannelQuantScaleParameter
|
||||
weight_zero_point = ChannelQuantZPParameter(
|
||||
data=torch.empty((sum(output_partition_sizes)),
|
||||
dtype=torch.int8),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader)
|
||||
else:
|
||||
assert self.qscheme == "per_tensor"
|
||||
weight_scale = PerTensorScaleParameter(data=torch.empty(
|
||||
len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
PerTensorZPParameter = PerTensorScaleParameter
|
||||
weight_zero_point = PerTensorZPParameter(
|
||||
data=torch.empty(len(output_partition_sizes),
|
||||
dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
layer.register_parameter("weight_zero_point", weight_zero_point)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.float32),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
input_zero_point = BasevLLMParameter(data=torch.empty(
|
||||
1, dtype=torch.int8),
|
||||
weight_loader=weight_loader)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
self.kernel = kernel_type(c=scaled_mm_linear_kernel_config,
|
||||
w_q_param_name="weight",
|
||||
w_s_param_name="weight_scale",
|
||||
i_s_param_name="input_scale",
|
||||
i_zp_param_name="input_zero_point",
|
||||
azp_adj_param_name="azp_adj")
|
||||
|
||||
# Checkpoints are serialized in quark format, which is
|
||||
# different from the format the kernel may want. Handle repacking here.
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
layer.register_parameter("weight_zero_point", None)
|
||||
delattr(layer, 'weight_zero_point')
|
||||
if self.input_symmetric:
|
||||
layer.register_parameter("input_zero_point", None)
|
||||
delattr(layer, 'input_zero_point')
|
||||
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
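For intuition about the scale and zero-point parameters registered above, here is a minimal sketch of static asymmetric int8 quantization (a toy reference, not the ScaledMM kernel path; all tensors are made up):

import torch

x = torch.randn(4, 8)
qmin, qmax = -128, 127

# Static asymmetric quantization: one scale and one zero point for the
# whole tensor, analogous to input_scale / input_zero_point.
scale = (x.max() - x.min()) / (qmax - qmin)
zero_point = torch.round(qmin - x.min() / scale)

x_q = torch.clamp(torch.round(x / scale) + zero_point,
                  qmin, qmax).to(torch.int8)

# Dequantized approximation of x. A symmetric scheme (input_symmetric)
# fixes the zero point at 0, which is why process_weights_after_loading
# above can drop input_zero_point in that case.
x_dq = (x_q.float() - zero_point) * scale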
105
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
105
vllm/model_executor/layers/quantization/quark/utils.py
Normal file
@@ -0,0 +1,105 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable, Mapping
from types import MappingProxyType
from typing import Any, Optional

import regex as re


def deep_compare(dict1: Any, dict2: Any) -> bool:
    if type(dict1) is not type(dict2):
        return False
    if isinstance(dict1, dict):
        if dict1.keys() != dict2.keys():
            return False
        return all(deep_compare(dict1[k], dict2[k]) for k in dict1)
    elif isinstance(dict1, list):
        return set(dict1) == set(dict2)
    else:
        return dict1 == dict2


def should_ignore_layer(
        layer_name: Optional[str],
        ignore: Iterable[str],
        fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
    if layer_name is None:
        return False

    # layer_name = model.layers.0.self_attn.qkv_proj
    # proj_name = qkv_proj
    proj_name = layer_name.split(".")[-1]

    # Fused layers like gate_up_proj or qkv_proj will not be fused
    # in the safetensors checkpoint. So, we convert the name
    # from the fused version to unfused + check to make sure that
    # each shard of the fused layer has the same scheme.
    if proj_name in fused_mapping:
        shard_proj_names = fused_mapping[proj_name]

        # Convert fused_name --> [shard_names]
        shard_names = [
            layer_name.replace(proj_name, shard_proj_name)
            for shard_proj_name in shard_proj_names
        ]

        # Layer should be ignored if its shards are ignored.
        should_ignore_layer = None
        for shard_name in shard_names:
            should_ignore_shard = check_equal_or_regex_match(
                layer_name=shard_name, targets=ignore)

            # If shard_idx=0, set layer ignore to match shard.
            if should_ignore_layer is None:
                should_ignore_layer = should_ignore_shard

            # If shard_idx=1+, confirm the scheme matches prior shards.
            elif should_ignore_shard != should_ignore_layer:
                raise ValueError(f"Found different quantization schemes for "
                                 f"{shard_proj_names} in {layer_name}. vLLM "
                                 "requires all to use the same scheme.")

    # Unfused layers like down_proj and o_proj will match
    # the safetensors checkpoint already.
    else:
        should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name,
                                                         targets=ignore)

    assert should_ignore_layer is not None
    return should_ignore_layer


def check_equal_or_regex_match(layer_name: str,
                               targets: Iterable[str]) -> bool:
    """
    Checks whether layer_name is exactly equal to, or (for targets that
    start with 're:') a regex match for, any target in the list.
    """
    for target in targets:
        if _is_equal_or_regex_match(layer_name, target):
            return True
    return False


def _is_equal_or_regex_match(value: str,
                             target: str,
                             check_contains: bool = False) -> bool:
    """
    Checks whether a value is exactly equal to, or a regex match for,
    the target if the target starts with 're:'. If check_contains is set
    to True, additionally checks whether the target string is contained
    within the value.
    """

    if target.startswith("re:"):
        pattern = target[3:]
        if re.match(pattern, value):
            return True
    elif check_contains:
        if target.lower() in value.lower():
            return True
    elif target == value:
        return True
    return False
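A hypothetical usage sketch of the helpers above (the exclude list and fused mapping here are made up; in practice they come from the quark config's "exclude" field and the model's packed_modules_mapping):

ignore = ["model.layers.0.self_attn.q_proj", "re:.*lm_head"]
fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

# Matched via the 're:' regex target.
should_ignore_layer("lm_head", ignore=ignore, fused_mapping=fused)  # True

# Fused layer: q_proj is excluded but k_proj/v_proj are not, so the
# shards disagree and should_ignore_layer raises ValueError.
should_ignore_layer("model.layers.0.self_attn.qkv_proj",
                    ignore=ignore,
                    fused_mapping=fused)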