[model] added support for w8a8int8 used by neuralmagic/Qwen2-0.5B-Ins… (#9642)
This commit is contained in:
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||
|
||||
@@ -2,10 +2,12 @@
|
||||
|
||||
from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||
|
||||
__all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
from torch.nn import Parameter
|
||||
|
||||
from sglang.srt.layers.parameter import (
|
||||
ChannelQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
PerTensorScaleParameter,
|
||||
)
|
||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
||||
from sglang.srt.layers.quantization.utils import requantize_with_max_scale
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
from sgl_kernel import int8_scaled_mm
|
||||
|
||||
|
||||
class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
|
||||
|
||||
def __init__(
|
||||
self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
|
||||
):
|
||||
self.strategy = strategy
|
||||
self.is_static_input_scheme = is_static_input_scheme
|
||||
self.input_symmetric = input_symmetric
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# lovelace and up
|
||||
return 89
|
||||
|
||||
def process_weights_after_loading(self, layer) -> None:
|
||||
# If per tensor, when we have a fused module (e.g. QKV) with per
|
||||
# tensor scales (thus N scales being passed to the kernel),
|
||||
# requantize so we can always run per channel
|
||||
if self.strategy == QuantizationStrategy.TENSOR:
|
||||
max_w_scale, weight = requantize_with_max_scale(
|
||||
weight=layer.weight,
|
||||
weight_scale=layer.weight_scale,
|
||||
logical_widths=layer.logical_widths,
|
||||
)
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
|
||||
|
||||
# If channelwise, scales are already lined up, so just transpose.
|
||||
elif self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight = layer.weight
|
||||
weight_scale = layer.weight_scale.data
|
||||
|
||||
layer.weight = Parameter(weight.t(), requires_grad=False)
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_scale = Parameter(weight_scale, requires_grad=False)
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown quantization strategy {self.strategy}")
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme and hasattr(layer, "input_scale"):
|
||||
if self.input_symmetric:
|
||||
layer.input_scale = Parameter(
|
||||
layer.input_scale.max(), requires_grad=False
|
||||
)
|
||||
else:
|
||||
input_scale = layer.input_scale
|
||||
input_zero_point = layer.input_zero_point
|
||||
|
||||
# reconstruct the ranges
|
||||
int8_traits = torch.iinfo(torch.int8)
|
||||
azps = input_zero_point.to(dtype=torch.int32)
|
||||
range_max = (input_scale * (int8_traits.max - azps)).max()
|
||||
range_min = (input_scale * (int8_traits.min - azps)).min()
|
||||
|
||||
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
|
||||
|
||||
# AZP loaded as int8 but used as int32
|
||||
azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
|
||||
|
||||
layer.input_scale = Parameter(scale, requires_grad=False)
|
||||
layer.input_zero_point = Parameter(azp, requires_grad=False)
|
||||
else:
|
||||
layer.input_scale = None
|
||||
layer.input_zero_point = None
|
||||
|
||||
# azp_adj is the AZP adjustment term, used to account for weights.
|
||||
# It does not depend on scales or azp, so it is the same for
|
||||
# static and dynamic quantization.
|
||||
# For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
# https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
|
||||
if not self.input_symmetric:
|
||||
weight = layer.weight
|
||||
azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
|
||||
if self.is_static_input_scheme:
|
||||
# cutlass_w8a8 requires azp to be folded into azp_adj
|
||||
# in the per-tensor case
|
||||
azp_adj = layer.input_zero_point * azp_adj
|
||||
layer.azp_adj = Parameter(azp_adj, requires_grad=False)
|
||||
else:
|
||||
layer.azp_adj = None
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
layer.logical_widths = output_partition_sizes
|
||||
|
||||
# WEIGHT
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition, input_size_per_partition, dtype=torch.int8
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
# WEIGHT SCALE
|
||||
if self.strategy == QuantizationStrategy.CHANNEL:
|
||||
weight_scale = ChannelQuantScaleParameter(
|
||||
data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
|
||||
output_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
else:
|
||||
assert self.strategy == QuantizationStrategy.TENSOR
|
||||
weight_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
layer.register_parameter("weight_scale", weight_scale)
|
||||
|
||||
# INPUT SCALE
|
||||
if self.is_static_input_scheme:
|
||||
input_scale = PerTensorScaleParameter(
|
||||
data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_scale", input_scale)
|
||||
|
||||
if not self.input_symmetric:
|
||||
# Note: compressed-tensors stores the zp using the same dtype
|
||||
# as the weights
|
||||
# AZP loaded as int8 but used as int32
|
||||
input_zero_point = PerTensorScaleParameter(
|
||||
data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
|
||||
)
|
||||
layer.register_parameter("input_zero_point", input_zero_point)
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
# TODO: add cutlass_scaled_mm_azp support
|
||||
x_q, x_scale = per_token_quant_int8(x)
|
||||
|
||||
return int8_scaled_mm(
|
||||
x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
|
||||
)
|
||||
Reference in New Issue
Block a user