Sync from v0.13
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from .quark_ocp_mx import QuarkOCP_MX
|
||||
from .quark_scheme import QuarkScheme
|
||||
from .quark_w8a8_fp8 import QuarkW8A8Fp8
|
||||
from .quark_w8a8_int8 import QuarkW8A8Int8
|
||||
|
||||
__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8", "QuarkOCP_MX"]
|
||||
@@ -0,0 +1,343 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from fractions import Fraction
|
||||
from functools import cache, partial
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from vllm import envs
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
|
||||
dequant_mxfp4,
|
||||
quant_dequant_mxfp4,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
|
||||
dequant_mxfp6,
|
||||
quant_dequant_mxfp6,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
|
||||
OCP_MX_BLOCK_SIZE,
|
||||
OCP_MX_Scheme,
|
||||
)
|
||||
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .quark_scheme import QuarkScheme
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# TODO: move registration of custom op to aiter_ops.py
|
||||
# `from vllm._aiter_ops import rocm_aiter_ops`
|
||||
# use `rocm_aiter_ops.is_asm_fp4_gemm_dynamic_quant_enabled()`
|
||||
# for envs checks which does not require @cache anymore.
|
||||
# triton kernel is torch compile compatible.
|
||||
# does not require direct registration.
|
||||
# use `rocm_aiter_ops.triton_fp4_gemm_dynamic_qaunt`.
|
||||
@cache
def is_rocm_aiter_fp4_asm_gemm_enabled() -> bool:
    """Return True when the AITER FP4 ASM GEMM path is switched on.

    Requires a ROCm platform together with both the
    ``VLLM_ROCM_USE_AITER`` and ``VLLM_ROCM_USE_AITER_FP4_ASM_GEMM``
    environment toggles. The result is cached for the process lifetime.
    """
    if not current_platform.is_rocm():
        return False
    return envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM and envs.VLLM_ROCM_USE_AITER
|
||||
|
||||
|
||||
try:
|
||||
from aiter.ops.shuffle import shuffle_weight
|
||||
from aiter.ops.triton.gemm_afp4wfp4 import (
|
||||
gemm_afp4wfp4,
|
||||
gemm_afp4wfp4_preshuffled_weight_scales,
|
||||
)
|
||||
from aiter.ops.triton.quant import dynamic_mxfp4_quant
|
||||
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
|
||||
if is_rocm_aiter_fp4_asm_gemm_enabled():
|
||||
from aiter import gemm_a4w4, per_1x32_f4_quant_hip
|
||||
|
||||
def gemm_with_dynamic_quant(
    x: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    rocm_use_aiter_fp4_asm_gemm: bool = False,
    out_dtype: torch.dtype | None = torch.bfloat16,
    x_scales: torch.Tensor | None = None,
) -> torch.Tensor:
    """MXFP4 GEMM with on-the-fly activation quantization.

    If ``x_scales`` is None, the activation ``x`` is dynamically quantized
    to MXFP4 first; otherwise ``x`` is assumed to be pre-quantized with the
    given scales. Dispatches between three AITER kernels:

    - Triton ``gemm_afp4wfp4_preshuffled_weight_scales`` for small M when a
      tuned config exists for (N, K);
    - the ASM ``gemm_a4w4`` kernel otherwise (when
      ``rocm_use_aiter_fp4_asm_gemm`` is set);
    - Triton ``gemm_afp4wfp4`` in the default path.

    :param x: activation input; first dim is treated as M.
    :param weight: packed FP4 weight (dim 0 is N, dim 1 is K).
    :param weight_scale: per-group weight scales; layout expected by the
        chosen kernel (preshuffled for the ASM path — see
        ``process_weights_after_loading``).
    :param rocm_use_aiter_fp4_asm_gemm: select the ASM/preshuffled path.
    :param out_dtype: dtype of the returned tensor.
    :param x_scales: optional pre-computed activation scales.
    :return: (M, N) result tensor in ``out_dtype``.
    """
    M = x.shape[0]
    N = weight.shape[0]
    K = weight.shape[1]
    if rocm_use_aiter_fp4_asm_gemm:
        if M <= 64 and rocm_aiter_ops.is_triton_gemm_afp4wfp4_presh_ws_tuned(N, K):
            if x_scales is None:
                # use hip quant kernel for performance
                if M >= 32:
                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
                else:
                    x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=False)
            else:
                x_q = x
                x_s = x_scales

            # Reinterpret scales as uint8 in the layout the preshuffled
            # kernel expects; for M < 32 drop padding rows first.
            if M >= 32:
                x_s = x_s.view(torch.uint8).view(x_s.shape[0] // 32, -1)
            else:
                x_s = x_s[:M, ...].view(torch.uint8)

            y = torch.empty(M, N, device=x_q.device, dtype=out_dtype)
            gemm_afp4wfp4_preshuffled_weight_scales(
                x_q.view(torch.uint8),
                weight.view(torch.uint8).view(weight.shape[0] // 16, -1),
                x_s,
                weight_scale.view(torch.uint8).view(
                    weight_scale.shape[0] // 32, -1
                ),
                out_dtype,
                y,
            )
        else:
            if x_scales is None:
                # use hip quant kernel for performance
                x_q, x_s = per_1x32_f4_quant_hip(x, shuffle=True)
            else:
                x_q = x
                x_s = x_scales

            # 32 alignment is enough for dim0 padding of output for
            # gemm_a4w4 kernel
            y = torch.empty(
                (M + 31) // 32 * 32,
                weight.shape[0],
                device=x_q.device,
                dtype=out_dtype,
            )

            gemm_a4w4(
                x_q, weight, x_s, weight_scale.view(x_s.dtype), y, bpreshuffle=True
            )
        # Slice away any dim0 padding (no-op for the preshuffled branch,
        # where y was allocated exactly (M, N)).
        return y[:M]
    else:
        if x_scales is None:
            x_q, x_s = dynamic_mxfp4_quant(x)
        else:
            x_q = x
            x_s = x_scales
        y = torch.empty(
            x_q.shape[0], weight.shape[0], device=x_q.device, dtype=out_dtype
        )

        gemm_afp4wfp4(x_q, weight, x_s, weight_scale.T, out_dtype, y)
        return y
||||
|
||||
def gemm_with_dynamic_quant_fake(
|
||||
x: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
x_scales: torch.Tensor = None,
|
||||
rocm_use_aiter_fp4_asm_gemm: bool = False,
|
||||
out_dtype: torch.dtype | None = torch.bfloat16,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty(
|
||||
(*x.shape[:-1], weight.shape[0]), dtype=out_dtype, device=x.device
|
||||
)
|
||||
|
||||
# Register the GEMM as a vLLM custom op (torch.ops.vllm.*) so that
# torch.compile can trace through it via the fake implementation above.
direct_register_custom_op(
    op_name="gemm_with_dynamic_quant",
    op_func=gemm_with_dynamic_quant,
    mutates_args=[],
    fake_impl=gemm_with_dynamic_quant_fake,
    dispatch_key=current_platform.dispatch_key,
)
|
||||
except (ImportError, AttributeError):
|
||||
dynamic_mxfp4_quant = gemm_afp4wfp4 = None
|
||||
|
||||
|
||||
class QuarkOCP_MX(QuarkScheme):
    """Quark scheme for OCP MX microscaling formats (MXFP4 / MXFP6).

    Weights are stored packed in uint8 (2 FP4 values per byte, or 4 FP6
    values per 3 bytes) with one uint8 scale per ``OCP_MX_BLOCK_SIZE``
    group. When the platform natively supports MX and both input and
    weight dtypes are mxfp4, the fused AITER custom op
    ``torch.ops.vllm.gemm_with_dynamic_quant`` is used; otherwise the
    layer is emulated via weight dequantization plus activation
    quant-dequant (QDQ), with the linear computed in high precision.
    """

    def __init__(
        self, weight_quant_spec: dict[str, Any], input_quant_spec: dict[str, Any]
    ):
        """Derive dtypes, packing factors and (de)quant callables.

        :param weight_quant_spec: Quark spec dict for weights; its
            ``"dtype"`` (e.g. ``"fp4"``, ``"fp6_e3m2"``) is mapped to the
            corresponding ``"mxfp*"`` name.
        :param input_quant_spec: Quark spec dict for activations; must be
            dynamic (static input scales raise NotImplementedError).
        """
        self.out_dtype = torch.get_default_dtype()
        self.qscheme = "per_group"
        self.weight_quant_spec = weight_quant_spec
        self.input_quant_spec = input_quant_spec

        # "fp4" -> "mxfp4", "fp6_e3m2" -> "mxfp6_e3m2", etc.
        self.weight_dtype = weight_quant_spec["dtype"].replace("fp", "mxfp")
        self.input_dtype = input_quant_spec["dtype"].replace("fp", "mxfp")

        self.ocp_mx_scheme = OCP_MX_Scheme.from_quant_dtype(
            self.input_dtype, self.weight_dtype
        )

        # Packing: FP4 fits 2 values per byte; FP6 fits 4 values per 3
        # bytes, hence the fractional 8/6 factor.
        if self.weight_dtype == "mxfp4":
            self.packed_factor: int | Fraction = 2
            self.dequant_func = dequant_mxfp4
        else:
            self.packed_factor = Fraction(numerator=8, denominator=6)
            self.dequant_func = partial(
                dequant_mxfp6, quant_dtype=self.weight_dtype.replace("mx", "")
            )

        if self.input_dtype == "mxfp4":
            self.quant_dequant_func = quant_dequant_mxfp4
        else:
            self.quant_dequant_func = partial(
                quant_dequant_mxfp6, quant_dtype=self.input_dtype.replace("mx", "")
            )

        self.static_input_scales = not input_quant_spec.get("is_dynamic")

        if self.static_input_scales:
            raise NotImplementedError(
                "QuarkOCP_MX with static input scales is currently not "
                "implemented. Please open an issue."
            )

        # TODO: integrate (or test) mixed-precision kernel.
        # Emulate unless the platform supports native MX AND both sides
        # are mxfp4 (the only combination with an integrated kernel).
        self.emulate = not current_platform.supports_mx() or (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        )

        self.rocm_use_aiter_fp4_asm_gemm = is_rocm_aiter_fp4_asm_gemm_enabled()

        if not self.emulate and (dynamic_mxfp4_quant is None or gemm_afp4wfp4 is None):
            # Currently need these kernels if not emulating
            raise NotImplementedError(
                f"{self.__class__.__name__} requires AITER to be installed "
                "for non-emulation mode! Please refer to "
                "https://github.com/ROCm/aiter for installation details."
            )

        if not current_platform.supports_mx():
            logger.warning_once(
                "The current platform does not support native MXFP4/MXFP6 "
                "computation. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

        if current_platform.supports_mx() and (
            self.input_dtype != "mxfp4" or self.weight_dtype != "mxfp4"
        ):
            logger.warning_once(
                "The current platform supports native MXFP4/MXFP6 "
                f"computation, but kernels for input_dtype={self.input_dtype} "
                f"and weight_dtype={self.weight_dtype} are not yet integrated "
                "in vLLM. Simulated weight dequantization and activation "
                "QDQ (quantize and dequantize) will be used, with the linear "
                "layers computed in high precision."
            )

    def get_packed_dim(self, dim: int, quant_dtype: str) -> int:
        """Return the packed (uint8) length of a dimension of ``dim``
        elements stored in ``quant_dtype``.

        :raises NotImplementedError: for any dtype other than mxfp4 /
            mxfp6_e3m2 / mxfp6_e2m3.
        """
        if quant_dtype == "mxfp4":
            assert dim % 2 == 0
            return dim // 2
        elif quant_dtype in {"mxfp6_e3m2", "mxfp6_e2m3"}:
            # FP6 packs 4 * 6 = 24 bits on 3 bytes.
            assert (dim * 3) % 4 == 0
            return (dim * 3) // 4
        else:
            raise NotImplementedError(
                "Unsupported quant_dtype in QuarkOCP_MX.get_packed_dim, "
                f"got quant_dtype={quant_dtype}. Something is wrong, please "
                "open an issue."
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # Minimum device capability 7.0.
        return 70

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Re-wrap loaded weights as plain Parameters and, for the native
        kernel paths, pre-arrange weight/scale layouts:

        - ASM gemm path: preshuffle both the weight scales (32x8 tiled
          permutation) and the weight itself (16x16 layout).
        - Triton path: transpose the scales once so apply_weights can pass
          ``weight_scale.T`` cheaply.
        """
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)

        if self.emulate:
            layer.weight_scale = torch.nn.Parameter(
                layer.weight_scale.data, requires_grad=False
            )
        else:
            if self.rocm_use_aiter_fp4_asm_gemm:
                # shuffle weight scale
                weight_scale_shuffle = layer.weight_scale.data
                sm, sn = weight_scale_shuffle.shape
                weight_scale_shuffle = weight_scale_shuffle.view(
                    sm // 32, 2, 16, sn // 8, 2, 4, 1
                )
                weight_scale_shuffle = weight_scale_shuffle.permute(
                    0, 3, 5, 2, 4, 1, 6
                ).contiguous()
                weight_scale_shuffle = weight_scale_shuffle.view(sm, sn)
                layer.weight_scale = torch.nn.Parameter(
                    weight_scale_shuffle, requires_grad=False
                )

                # shuffle weight
                weight_shuffle = layer.weight.data
                weight_shuffle = shuffle_weight(weight_shuffle, layout=(16, 16))
                layer.weight = torch.nn.Parameter(weight_shuffle, requires_grad=False)
            else:
                layer.weight_scale = torch.nn.Parameter(
                    layer.weight_scale.data.T.contiguous(), requires_grad=False
                )

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register the packed uint8 ``weight`` and the per-group uint8
        ``weight_scale`` parameters on ``layer``.

        The weight's input dim is packed by ``self.packed_factor``; the
        scale has one entry per ``OCP_MX_BLOCK_SIZE`` input elements.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = PackedvLLMParameter(
            data=torch.empty(
                output_size_per_partition,
                self.get_packed_dim(input_size_per_partition, self.weight_dtype),
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=1,
            packed_factor=self.packed_factor,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        weight_scale = GroupQuantScaleParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition // OCP_MX_BLOCK_SIZE,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Forward pass: emulated dequant + QDQ linear, or the fused
        AITER custom op in native mode."""
        if self.emulate:
            dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
            qdq_x = self.quant_dequant_func(x)
            return F.linear(qdq_x, dq_w, bias)
        else:
            # NOTE(review): bias is not applied on this path — presumably
            # callers pass bias=None in non-emulation mode; confirm.
            return torch.ops.vllm.gemm_with_dynamic_quant(
                x,
                layer.weight,
                layer.weight_scale,
                self.rocm_use_aiter_fp4_asm_gemm,
                self.out_dtype,
            )
|
||||
@@ -0,0 +1,55 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["QuarkScheme"]
|
||||
|
||||
|
||||
class QuarkScheme(ABC):
    """
    Abstract class used to describe the weight creation and forward pass
    of different quantization schemes supported by Quark.
    """

    @classmethod
    @abstractmethod
    def get_min_capability(cls) -> int:
        """
        Get minimum device capability required by this scheme.
        """
        raise NotImplementedError

    @abstractmethod
    def create_weights(self, *args, **kwargs):
        """
        Weight creation for the particular scheme. Inputs to this function
        are the layer the parameters are registered on, the partition
        sizes/dtype, and the weight-loader callback (see concrete
        schemes' signatures).
        """
        raise NotImplementedError

    @abstractmethod
    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ):
        """
        Run the forward pass for the particular scheme. This is where
        scheme-specific dequant/quant steps/kernels should be applied.

        :param layer: torch.nn.Module with the registered weights and
            other parameters relevant to the particular scheme.
        :param x: input to the layer
        :param bias: bias parameter

        """
        raise NotImplementedError

    @abstractmethod
    def process_weights_after_loading(self, layer: torch.nn.Module):
        """
        Called after weight loading is complete for any cleanup that
        needs to occur.
        """
        raise NotImplementedError
|
||||
@@ -0,0 +1,179 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
Fp8LinearOp,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
requantize_with_max_scale,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
__all__ = ["QuarkW8A8Fp8"]
|
||||
|
||||
|
||||
class QuarkW8A8Fp8(QuarkScheme):
    """Quark W8A8 FP8 scheme: float8_e4m3fn weights with per-tensor or
    per-channel weight scales, and static or dynamic activation scales,
    executed through ``Fp8LinearOp``."""

    def __init__(
        self, weight_config: dict[str, Any], input_config: dict[str, Any] | None
    ):
        """
        :param weight_config: Quark weight config; ``qscheme`` selects
            per-tensor vs per-channel weight scales.
        :param input_config: Quark activation config, or None for
            dynamic-only activation quantization.
        """
        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
        self.is_static_input_scheme: bool = False
        self.input_qscheme: str | None = None
        if input_config is not None:
            self.is_static_input_scheme = not cast(bool, input_config.get("is_dynamic"))
            self.input_qscheme = cast(str, input_config.get("qscheme"))

        # Dynamic per-channel activation quantization maps to a per-token
        # group shape for the fp8 linear op.
        per_token = (
            not self.is_static_input_scheme and self.input_qscheme == "per_channel"
        )
        self.act_quant_group_shape = (
            GroupShape.PER_TOKEN if per_token else GroupShape.PER_TENSOR
        )
        self.fp8_linear = Fp8LinearOp(
            act_quant_static=self.is_static_input_scheme,
            act_quant_group_shape=self.act_quant_group_shape,
        )
        self.out_dtype = torch.get_default_dtype()

    @classmethod
    def get_min_capability(cls) -> int:
        # lovelace and up
        return 89

    def process_weights_after_loading(self, layer) -> None:
        """Normalize and repack loaded fp8 weights.

        On fnuz platforms, weights/scales are first converted from e4m3fn
        to e4m3fnuz. Per-tensor fused modules are then requantized to a
        single max scale; per-channel scales are used as-is (reshaped to
        a column for per-token activation quant). The weight is stored
        transposed for the fp8 linear op.
        """
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
        if self.weight_qscheme == "per_tensor":
            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=layer.weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                max_w_scale = layer.weight_scale
                weight = layer.weight

            max_w_scale, weight = requantize_with_max_scale(
                weight=weight,
                weight_scale=max_w_scale,
                logical_widths=layer.logical_widths,
            )

            layer.weight = Parameter(weight.t(), requires_grad=False)
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
        elif self.weight_qscheme == "per_channel":
            weight = layer.weight

            if current_platform.is_fp8_fnuz():
                input_scale = getattr(layer, "input_scale", None)
                weight, weight_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
                    weight=weight,
                    weight_scale=layer.weight_scale,
                    input_scale=input_scale,
                )
                if input_scale is not None:
                    layer.input_scale = Parameter(input_scale, requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data
            if self.act_quant_group_shape == GroupShape.PER_TOKEN:
                # Per-token activation quant expects a column vector of
                # weight scales.
                weight_scale = weight_scale.view(-1, 1)
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
            raise ValueError(f"Unknown quantization scheme {self.weight_qscheme}")

        # INPUT SCALE
        if self.is_static_input_scheme:
            # Collapse fused-module input scales to a single max scale.
            layer.input_scale = Parameter(layer.input_scale.max(), requires_grad=False)
        else:
            layer.input_scale = None

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Register fp8 ``weight``, ``weight_scale`` (per-channel or
        per-tensor) and, for static schemes, ``input_scale`` on ``layer``.
        """
        output_size_per_partition = sum(output_partition_sizes)
        layer.logical_widths = output_partition_sizes

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
                dtype=torch.float8_e4m3fn,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
        if self.weight_qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.weight_qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )

        # min requirement for fp8 kernels
        weight_scale[:] = torch.finfo(torch.float32).min
        layer.register_parameter("weight_scale", weight_scale)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            input_scale[:] = torch.finfo(torch.float32).min
            layer.register_parameter("input_scale", input_scale)

    def apply_weights(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the fp8 linear with the processed weights/scales."""
        return self.fp8_linear.apply(
            input=x,
            weight=layer.weight,
            weight_scale=layer.weight_scale,
            out_dtype=self.out_dtype,
            input_scale=layer.input_scale,
            bias=bias,
        )
|
||||
@@ -0,0 +1,139 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.kernels.scaled_mm import (
|
||||
ScaledMMLinearLayerConfig,
|
||||
choose_scaled_mm_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class QuarkW8A8Int8(QuarkScheme):
    """Quark W8A8 INT8 scheme: int8 weights (per-tensor or per-channel
    scales) with static or dynamic, symmetric or asymmetric activation
    quantization, delegated to a ScaledMM linear kernel backend."""

    # Backend names already logged, so each kernel choice is reported once
    # per process (class-level, shared across instances).
    _kernel_backends_being_used: set[str] = set()

    def __init__(
        self,
        qscheme: str,
        is_static_input_scheme: bool | None,
        input_symmetric: bool | None,
    ):
        """
        :param qscheme: "per_channel" or "per_tensor" weight scales.
        :param is_static_input_scheme: static activation scales if True.
        :param input_symmetric: symmetric activation quant if True.
        """
        self.qscheme = qscheme
        self.is_static_input_scheme = is_static_input_scheme
        self.input_symmetric = input_symmetric

    @classmethod
    def get_min_capability(cls) -> int:
        # turing and up
        return 75

    def create_weights(
        self,
        layer: torch.nn.Module,
        output_partition_sizes: list[int],
        input_size_per_partition: int,
        params_dtype: torch.dtype,
        weight_loader: Callable,
        **kwargs,
    ):
        """Select a ScaledMM kernel backend and register int8 weight,
        scale and zero-point parameters (plus static input scale/zp when
        configured) on ``layer``."""
        layer.logical_widths = output_partition_sizes

        scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig(
            is_channelwise=(self.qscheme == "per_channel"),
            is_static_input_scheme=(self.is_static_input_scheme is True),
            input_symmetric=(self.input_symmetric is True),
        )

        kernel_type = choose_scaled_mm_linear_kernel(scaled_mm_linear_kernel_config)

        if kernel_type.__name__ not in self._kernel_backends_being_used:
            logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__)
            self._kernel_backends_being_used.add(kernel_type.__name__)

        # WEIGHT
        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )

        layer.register_parameter("weight", weight)

        # WEIGHT SCALE
        if self.qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader,
            )
            # Zero points share the scale parameter class (same layout,
            # int8 storage).
            ChannelQuantZPParameter = ChannelQuantScaleParameter
            weight_zero_point = ChannelQuantZPParameter(
                data=torch.empty((sum(output_partition_sizes)), dtype=torch.int8),
                output_dim=0,
                weight_loader=weight_loader,
            )
        else:
            assert self.qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                weight_loader=weight_loader,
            )
            PerTensorZPParameter = PerTensorScaleParameter
            weight_zero_point = PerTensorZPParameter(
                data=torch.empty(len(output_partition_sizes), dtype=torch.int8),
                weight_loader=weight_loader,
            )
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_zero_point", weight_zero_point)

        # INPUT SCALE
        if self.is_static_input_scheme:
            input_scale = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
            )
            layer.register_parameter("input_scale", input_scale)

            input_zero_point = BasevLLMParameter(
                data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
            )
            layer.register_parameter("input_zero_point", input_zero_point)

        self.kernel = kernel_type(
            c=scaled_mm_linear_kernel_config,
            w_q_param_name="weight",
            w_s_param_name="weight_scale",
            i_s_param_name="input_scale",
            i_zp_param_name="input_zero_point",
            azp_adj_param_name="azp_adj",
        )

    # Checkpoints are serialized in quark format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Drop zero-point params the kernel does not need (weight zp
        always; input zp when symmetric), then let the kernel repack."""
        # register_parameter(None) clears the entry before delattr removes
        # the attribute entirely.
        layer.register_parameter("weight_zero_point", None)
        delattr(layer, "weight_zero_point")
        if self.input_symmetric:
            layer.register_parameter("input_zero_point", None)
            delattr(layer, "input_zero_point")

        self.kernel.process_weights_after_loading(layer)

    def apply_weights(
        self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
    ) -> torch.Tensor:
        """Forward through the selected ScaledMM kernel backend."""
        return self.kernel.apply_weights(layer, x, bias)
|
||||
Reference in New Issue
Block a user