[1/n] chore: decouple quantization implementation from vLLM dependency (#7992)
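The diff replaces hard vLLM imports in the quantization layers with guarded imports plus local fallbacks (see the try/except blocks below). A minimal sketch of that pattern, with an illustrative helper name that is not part of the diff:

try:
    from vllm import _custom_ops as ops  # optional fast path

    VLLM_AVAILABLE = True
except ImportError:
    ops = None
    VLLM_AVAILABLE = False


def has_vllm_kernels() -> bool:
    # Hypothetical helper: callers branch on this instead of importing vllm directly.
    return VLLM_AVAILABLE and ops is not None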
@@ -1,10 +1,11 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional

import sglang.srt.layers.moe.fused_moe_triton.fused_moe  # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import (
    FusedMoE,
@@ -37,4 +38,6 @@ __all__ = [
    "fused_moe",
    "fused_experts",
    "get_config_file_name",
    "moe_align_block_size",
    "try_get_optimal_moe_config",
]

@@ -22,10 +22,6 @@ try:
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
    from vllm.model_executor.layers.quantization.gptq_marlin import (
        GPTQMarlinLinearMethod,
    )
    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
        GPTQMarlin24Config,
    )
@@ -59,7 +55,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import
from sglang.srt.layers.quantization.fp8 import Fp8Config
from sglang.srt.layers.quantization.gptq import (
    GPTQConfig,
    GPTQLinearMethod,
    GPTQMarlinConfig,
    GPTQMarlinLinearMethod,
    GPTQMarlinMoEMethod,
)
from sglang.srt.layers.quantization.modelopt_quant import (

@@ -1,47 +1,55 @@
import logging
from dataclasses import dataclass
from fractions import Fraction
from typing import Any, Callable, Dict, List, Optional, Union

import torch

from sglang.srt.layers.linear import LinearBase, set_weight_attrs
from sglang.srt.layers.linear import LinearBase, LinearMethodBase, set_weight_attrs
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.utils import replace_parameter
from sglang.srt.layers.quantization.marlin_utils import (
    apply_gptq_marlin_linear,
    check_marlin_supported,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_moe_permute_scales,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
    verify_marlin_supported,
)
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import replace_parameter, unpack_cols

try:
|
||||
from vllm import _custom_ops as ops
|
||||
except ImportError:
|
||||
ops = None
|
||||
|
||||
from sglang.srt.utils import is_cuda
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin import (
|
||||
FusedMoE,
|
||||
FusedMoEMethodBase,
|
||||
FusedMoeWeightScaleSupported,
|
||||
GPTQMarlinLinearMethod,
|
||||
marlin_moe_permute_scales,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
if _is_cuda:
|
||||
from sgl_kernel import fused_marlin_moe
|
||||
|
||||
VLLM_AVAILABLE = True
|
||||
except ImportError:
|
||||
VLLM_AVAILABLE = False
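# vLLM is absent: the assignments below bind lightweight stand-ins so type
# annotations and isinstance checks in this module still resolve at import time.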
|
||||
|
||||
GPTQLinearMethod = MarlinLinearMethod = Any
|
||||
|
||||
FusedMoEMethodBase = QuantizeMethodBase
|
||||
|
||||
class scalar_types:
|
||||
uint4b8 = "uint4b8"
|
||||
uint8b128 = "uint8b128"
|
||||
|
||||
FusedMoEMethodBase = QuantizeMethodBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -54,6 +62,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def gptq_marlin_moe_repack(
|
||||
b_q_weight: torch.Tensor,
|
||||
perm: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
num_bits: int,
|
||||
) -> torch.Tensor:
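# Repack each expert's GPTQ-packed weights into the Marlin kernel layout,
# one expert at a time, via the sgl_kernel repack op (this replaces the
# former vllm _custom_ops.gptq_marlin_moe_repack call).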
|
||||
num_experts = b_q_weight.shape[0]
|
||||
assert size_k % 16 == 0
|
||||
output = torch.empty(
|
||||
(num_experts, size_k // 16, size_n * (num_bits // 2)),
|
||||
device=b_q_weight.device,
|
||||
dtype=b_q_weight.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
|
||||
b_q_weight[e], perm[e], size_k, size_n, num_bits
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
@dataclass
|
||||
class MarlinLinearLayerConfig:
|
||||
full_weight_shape: tuple[int, int] # [in, out]
|
||||
partition_weight_shape: tuple[int, int]
|
||||
weight_type: ScalarType
|
||||
act_type: torch.dtype
|
||||
group_size: int
|
||||
zero_points: bool
|
||||
has_g_idx: bool
|
||||
|
||||
|
||||
class GPTQConfig(QuantizationConfig):
|
||||
"""Config class for GPTQ.
|
||||
|
||||
@@ -151,11 +191,16 @@ class GPTQConfig(QuantizationConfig):
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[GPTQLinearMethod]:
|
||||
) -> Optional["LinearMethodBase"]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization import get_linear_quant_method
|
||||
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
if isinstance(layer, LinearBase):
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
|
||||
elif isinstance(layer, FusedMoE):
|
||||
raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlinConfig(QuantizationConfig):
|
||||
@@ -313,14 +358,6 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
|
||||
if isinstance(layer, FusedMoE):
|
||||
return GPTQMarlinMoEMethod(self)
|
||||
# TODO: re-enable after SGLang syncs with vllm >= 0.7.3
|
||||
# if layer.num_experts > 32:
|
||||
# # For MoEs with many experts the moe_wna16 kernel is faster
|
||||
# return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
# layer, prefix
|
||||
# )
|
||||
# else:
|
||||
# return GPTQMarlinMoEMethod(self)
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
|
||||
|
||||
@classmethod
|
||||
@@ -344,112 +381,439 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
if (num_bits, sym) not in cls.TYPE_MAP:
|
||||
return False
|
||||
|
||||
assert (
|
||||
VLLM_AVAILABLE
|
||||
), "vllm is not installed, to use gptq_marlin, please install vllm"
|
||||
|
||||
return check_marlin_supported(
|
||||
quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
|
||||
)
|
||||
|
||||
|
||||
class MarlinConfig(QuantizationConfig):
|
||||
"""Config class for Marlin.
|
||||
class GPTQLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ.
|
||||
|
||||
Reference: https://github.com/IST-DASLab/marlin/tree/master
|
||||
Args:
|
||||
quant_config: The GPTQ quantization config.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
def __init__(self, quant_config: GPTQConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
group_size: int,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
# Group size for the quantization.
|
||||
self.group_size = group_size
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
if self.group_size != 128 and self.group_size != -1:
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
if input_size_per_partition % self.quant_config.group_size != 0:
|
||||
raise ValueError(
|
||||
"Currently, only group size 128 and -1 (channelwise) "
|
||||
"is supported for Marlin, but got group_size of "
|
||||
f"{self.group_size}"
|
||||
"The input size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size."
|
||||
)
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
|
||||
raise ValueError(
|
||||
"The output size is not aligned with the quantized "
|
||||
"weight shape. This can be caused by too large "
|
||||
"tensor parallel size."
|
||||
)
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // 4
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = 64
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = 128
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = 16
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"MarlinConfig(group_size={self.group_size}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> List[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
return cls(group_size, lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
is_marlin_format = check_marlin_format(hf_quant_cfg)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
||||
)
|
||||
|
||||
if is_marlin_format and is_valid_user_quant:
|
||||
msg = "The model is serialized in {} format. Using {} kernel.".format(
|
||||
cls.get_name(), cls.get_name()
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional[MarlinLinearMethod]:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
self.use_shuffle = True
|
||||
scale_and_zero_size = input_size // group_size
|
||||
scale_and_zero_input_dim = None
|
||||
if (
|
||||
input_size != input_size_per_partition
|
||||
and self.quant_config.group_size != -1
|
||||
):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
if self.quant_config.desc_act:
|
||||
self.use_shuffle = False
|
||||
else:
|
||||
# we need to partition qzeros and scales for exllama kernel
|
||||
scale_and_zero_size = input_size_per_partition // group_size
|
||||
scale_and_zero_input_dim = 0
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
g_idx = RowvLLMParameter(
|
||||
data=torch.tensor(
|
||||
[
|
||||
i // self.quant_config.group_size
|
||||
for i in range(input_size_per_partition)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
qzeros_args = {
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
scale_and_zero_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if scale_and_zero_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# for torch.compile
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# exllama needs the weight shuffled after it is loaded;
# we do the shuffle here, once, after loading completes
|
||||
if self.use_shuffle:
|
||||
if self.quant_config.desc_act:
|
||||
layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
|
||||
else:
|
||||
layer.g_idx.data = torch.empty(
|
||||
(0,), dtype=torch.int, device=layer.g_idx.device
|
||||
)
|
||||
ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
|
||||
output = ops.gptq_gemm(
|
||||
reshaped_x,
|
||||
layer.qweight,
|
||||
layer.qzeros,
|
||||
layer.scales,
|
||||
layer.g_idx,
|
||||
self.use_shuffle,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
if bias is not None:
|
||||
output.add_(bias)
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for GPTQ Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The GPTQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
|
||||
# Verify supported on platform.
|
||||
verify_marlin_supported(
|
||||
quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
self.kernel_config = MarlinLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act,
|
||||
)
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
# Determine sharding
|
||||
if marlin_repeat_scales_on_all_ranks(
|
||||
self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
|
||||
):
|
||||
# By setting scale_dim == None, weight_loader will
|
||||
# repeat the scales on each GPU in TP>1 case.
|
||||
scales_and_zp_input_dim = None
|
||||
scales_and_zp_size = input_size // group_size
|
||||
else:
|
||||
# By setting scale_dim == 0, weight_loader will
|
||||
# shard the scales in TP>1 case.
|
||||
scales_and_zp_input_dim = 0
|
||||
scales_and_zp_size = input_size_per_partition // group_size
|
||||
|
||||
# Quantized weights
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.pack_factor,
|
||||
output_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Activation order
|
||||
g_idx = RowvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
qzeros_args = {
|
||||
"data": torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition // self.quant_config.pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
scales_and_zp_size,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
|
||||
if scales_and_zp_input_dim is None:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
qzeros = PackedColumnParameter(
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
qzeros = PackedvLLMParameter(
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
**qzeros_args,
|
||||
)
|
||||
|
||||
layer.register_parameter("qweight", qweight)
|
||||
layer.register_parameter("g_idx", g_idx)
|
||||
layer.register_parameter("scales", scales)
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = getattr(layer, "qweight").device
|
||||
c = self.kernel_config
|
||||
|
||||
check_marlin_supports_shape(
|
||||
c.partition_weight_shape[1], # out_features
|
||||
c.partition_weight_shape[0], # in_features
|
||||
c.full_weight_shape[0], # in_features
|
||||
c.group_size,
|
||||
)
|
||||
|
||||
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
|
||||
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||
|
||||
# Allocate marlin workspace.
|
||||
self.workspace = marlin_make_workspace(device)
|
||||
|
||||
# Default names since marlin requires empty parameters for these,
|
||||
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||
self.w_q_name = "qweight"
|
||||
self.w_s_name = "scales"
|
||||
self.w_zp_name = "qzeros"
|
||||
self.w_gidx_name = "g_idx"
|
||||
|
||||
def _transform_param(
|
||||
layer: torch.nn.Module, name: Optional[str], fn: Callable
|
||||
) -> None:
|
||||
if name is not None and getattr(layer, name, None) is not None:
|
||||
|
||||
old_param = getattr(layer, name)
|
||||
new_param = fn(old_param)
|
||||
# replace the parameter with torch.nn.Parameter for TorchDynamo
|
||||
# compatibility
|
||||
replace_parameter(
|
||||
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
|
||||
)
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
|
||||
x.data.contiguous(),
|
||||
perm=layer.g_idx_sort_indices,
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
)
|
||||
return x
|
||||
|
||||
def transform_w_s(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||
x.data = marlin_permute_scales(
|
||||
x.data.contiguous(),
|
||||
size_k=c.partition_weight_shape[0],
|
||||
size_n=c.partition_weight_shape[1],
|
||||
group_size=c.group_size,
|
||||
)
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name)
|
||||
)
|
||||
_transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (
|
||||
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
|
||||
)
|
||||
_transform_param(
|
||||
layer,
|
||||
self.w_zp_name,
|
||||
lambda x: marlin_zero_points(
|
||||
unpack_cols(
|
||||
x.t(),
|
||||
c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1],
|
||||
),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits,
|
||||
),
|
||||
)
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
_transform_param(layer, self.w_q_name, transform_w_q)
|
||||
_transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
c = self.kernel_config
|
||||
|
||||
def _get_weight_params(
|
||||
layer: torch.nn.Module,
|
||||
) -> tuple[
|
||||
torch.Tensor, # w_q
|
||||
torch.Tensor, # w_s
|
||||
Optional[torch.Tensor], # w_zp,
|
||||
Optional[torch.Tensor], # w_gidx
|
||||
]:
|
||||
return (
|
||||
getattr(layer, self.w_q_name),
|
||||
getattr(layer, self.w_s_name),
|
||||
getattr(layer, self.w_zp_name or "", None),
|
||||
getattr(layer, self.w_gidx_name or "", None),
|
||||
)
|
||||
|
||||
w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
|
||||
|
||||
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||
# None for marlin
|
||||
return apply_gptq_marlin_linear(
|
||||
input=x,
|
||||
weight=w_q,
|
||||
weight_scale=w_s,
|
||||
weight_zp=w_zp, # type: ignore
|
||||
g_idx=w_gidx, # type: ignore
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=self.workspace,
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
|
||||
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
@@ -467,6 +831,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||
|
||||
intermediate_size = extra_weight_attrs.pop("intermediate_size")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
@@ -644,20 +1011,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
requires_grad=False,
|
||||
)
|
||||
# Repack weights
|
||||
marlin_w13_qweight = ops.gptq_marlin_moe_repack(
|
||||
marlin_w13_qweight = gptq_marlin_moe_repack(
|
||||
layer.w13_qweight,
|
||||
layer.w13_g_idx_sort_indices,
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
marlin_w2_qweight = gptq_marlin_moe_repack(
|
||||
layer.w2_qweight,
|
||||
layer.w2_g_idx_sort_indices,
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
self.quant_config.weight_bits,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# Repack scales
|
||||
@@ -698,13 +1065,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
) -> torch.Tensor:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
scoring_func == "softmax"
|
||||
), "Only softmax score func is supported for now."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
@@ -713,11 +1086,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
layer.w13_qweight,
|
||||
layer.w2_qweight,
|
||||
@@ -730,6 +1102,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
g_idx2=layer.w2_g_idx,
|
||||
sort_indices1=layer.w13_g_idx_sort_indices,
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
quant_type_id=self.quant_config.quant_type.id,
|
||||
num_bits=self.quant_config.weight_bits,
|
||||
is_k_full=self.is_k_full,
|
||||
).to(orig_dtype)
|
||||
|
||||
781
python/sglang/srt/layers/quantization/marlin_utils.py
Normal file
@@ -0,0 +1,781 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py

import logging
from typing import Any, Optional

import numpy
import torch

from sglang.srt.layers.linear import LinearBase, LinearMethodBase
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedvLLMParameter,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import get_device_capability

try:
    from vllm import _custom_ops as ops
except ImportError:
    ops = None

logger = logging.getLogger(__name__)

GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16

MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]

# If Marlin shows a performance issue, the variable below can be changed to
# False, which lets Marlin perform its global reductions in fp16 precision
# (instead of fp32) and therefore save on some memory movement.
USE_FP32_REDUCE_DEFAULT = True


# For binary size and compile time, we don't support the same set of types both
# with and without runtime zero-points. We support the common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so it's closer to the actual impl
def query_marlin_supported_quant_types(
|
||||
has_zp: Optional[bool] = None,
|
||||
include_fp_type: bool = True,
|
||||
device_capability: Optional[int] = None,
|
||||
):
|
||||
if device_capability is None:
|
||||
major, minor = get_device_capability()
|
||||
capability = major * 10 + minor
|
||||
device_capability = -1 if capability is None else capability
|
||||
|
||||
if device_capability < 80:
|
||||
return []
|
||||
|
||||
# - has_zp is True: return quant_types that have zero points
# - has_zp is False: return quant_types that have no zero points
# - has_zp is None: both
|
||||
if has_zp is None:
|
||||
types0 = query_marlin_supported_quant_types(
|
||||
False, include_fp_type, device_capability
|
||||
)
|
||||
types1 = query_marlin_supported_quant_types(
|
||||
True, include_fp_type, device_capability
|
||||
)
|
||||
return types0 + types1
|
||||
|
||||
if has_zp:
|
||||
# AWQ style, unsigned + runtime zero-point
|
||||
return [scalar_types.uint4]
|
||||
else:
|
||||
# GPTQ style, unsigned + symmetric bias
|
||||
res = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
if include_fp_type:
|
||||
res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
|
||||
return res
|
||||
|
||||
|
||||
def _check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: Optional[int],
|
||||
has_zp: bool,
|
||||
device_capability: Optional[int] = None,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
|
||||
if device_capability is None:
|
||||
major, minor = get_device_capability()
|
||||
capability = major * 10 + minor
|
||||
device_capability = -1 if capability is None else capability
|
||||
|
||||
supported_types = query_marlin_supported_quant_types(
|
||||
has_zp, True, device_capability
|
||||
)
|
||||
|
||||
if quant_type not in supported_types:
|
||||
return (
|
||||
False,
|
||||
f"Marlin does not support weight_bits = {quant_type}. "
|
||||
f"Only types = {supported_types} "
|
||||
f"are supported (for group_size = {group_size}, "
|
||||
f"device_capability = {device_capability}, zp = {has_zp}).",
|
||||
)
|
||||
if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
return (
|
||||
False,
|
||||
f"Marlin does not support group_size = {group_size}. "
|
||||
f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported.",
|
||||
)
|
||||
|
||||
return True, None
|
||||
|
||||
|
||||
def check_marlin_supported(
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
has_zp: bool = False,
|
||||
device_capability: Optional[int] = None,
|
||||
) -> bool:
|
||||
cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
|
||||
return cond
|
||||
|
||||
|
||||
def verify_marlin_supported(
|
||||
quant_type: ScalarType, group_size: int, has_zp: bool = False
|
||||
) -> None:
|
||||
cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
|
||||
if not cond:
|
||||
assert err_msg is not None
|
||||
raise ValueError(err_msg)
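# Illustrative usage note (not part of this file): config-level capability checks
# in this diff call check_marlin_supported(), while GPTQMarlinLinearMethod.__init__
# calls verify_marlin_supported() to fail fast with the reason string.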
|
||||
|
||||
|
||||
def verify_marlin_supports_shape(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int,
|
||||
group_size: int,
|
||||
) -> None:
|
||||
|
||||
# Validate output_size_per_partition
|
||||
if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq."
|
||||
)
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible "
|
||||
f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq."
|
||||
)
|
||||
|
||||
if group_size < input_size and input_size_per_partition % group_size != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = {input_size_per_partition}"
|
||||
f" is not divisible by group_size = {group_size}. "
|
||||
"Consider reducing tensor_parallel_size or running "
|
||||
"with --quantization gptq."
|
||||
)
|
||||
|
||||
|
||||
def check_marlin_supports_shape(
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
input_size: int,
|
||||
group_size: int,
|
||||
) -> tuple[bool, Optional[str]]:
|
||||
try:
|
||||
verify_marlin_supports_shape(
|
||||
output_size_per_partition, input_size_per_partition, input_size, group_size
|
||||
)
|
||||
except ValueError as e:
|
||||
return False, e.__str__()
|
||||
return True, None
|
||||
|
||||
|
||||
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
output_size_per_partition = (
|
||||
getattr(layer, "output_size_per_partition", None) or layer.output_size
|
||||
)
|
||||
input_size_per_partition = (
|
||||
getattr(layer, "input_size_per_partition", None) or layer.input_size
|
||||
)
|
||||
|
||||
return check_marlin_supports_shape(
|
||||
output_size_per_partition=output_size_per_partition,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
input_size=layer.input_size,
|
||||
group_size=group_size,
|
||||
)[0]
|
||||
|
||||
|
||||
def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
hidden_size = layer.hidden_size
|
||||
intermediate_size_per_partition = layer.intermediate_size_per_partition
|
||||
# apply_router_weight_on_input is not supported for moe marlin
|
||||
supports_router_weight = not layer.apply_router_weight_on_input
|
||||
# moe marlin requires the activation to be silu
|
||||
supports_activation = layer.activation == "silu"
|
||||
|
||||
# gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
|
||||
# down: (n, k) = (hidden_size, intermediate_size_per_partition)
|
||||
# moe marlin requires n % 128 == 0 and k % 64 == 0
|
||||
supports_shape = (
|
||||
hidden_size % 128 == 0
|
||||
and intermediate_size_per_partition % max(64, group_size) == 0
|
||||
)
|
||||
supports_group_size = group_size in [-1, 32, 64, 128]
|
||||
return (
|
||||
supports_shape
|
||||
and supports_group_size
|
||||
and supports_router_weight
|
||||
and supports_activation
|
||||
)
|
||||
|
||||
|
||||
def marlin_make_workspace(
|
||||
device: torch.device, max_blocks_per_sm: int = 1
|
||||
) -> torch.Tensor:
|
||||
# In the new marlin kernel, we use the number of threadblocks as the workspace
# size. The number of threadblocks is sms_count * max_blocks_per_sm.
|
||||
sms = torch.cuda.get_device_properties(device).multi_processor_count
|
||||
return torch.zeros(
|
||||
sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
|
||||
return (not act_order) or (act_order and not is_row_parallel)
|
||||
|
||||
|
||||
def marlin_repeat_scales_on_all_ranks(
|
||||
act_order: bool, group_size: int, is_row_parallel: bool
|
||||
) -> bool:
|
||||
# Need to repeat scales on every rank if act_order, or if
# channelwise and RowParallelLinear
|
||||
is_channelwise = group_size == -1
|
||||
return act_order or (is_channelwise and is_row_parallel)
|
||||
|
||||
|
||||
def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(
|
||||
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
|
||||
return torch.nn.Parameter(
|
||||
torch.empty(0, dtype=torch.int, device=device), requires_grad=False
|
||||
)
|
||||
|
||||
|
||||
def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
|
||||
return g_idx[g_idx_sort_indices], g_idx_sort_indices
|
||||
|
||||
|
||||
def get_scale_perms():
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(4):
|
||||
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
|
||||
def marlin_permute_scales(
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int
|
||||
) -> torch.Tensor:
|
||||
|
||||
scale_perm, scale_perm_single = get_scale_perms()
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def marlin_moe_permute_scales(
|
||||
s: torch.Tensor,
|
||||
size_k: int,
|
||||
size_n: int,
|
||||
group_size: int,
|
||||
):
|
||||
num_experts = s.shape[0]
|
||||
output = torch.empty(
|
||||
(num_experts, s.shape[1], s.shape[2]),
|
||||
device=s.device,
|
||||
dtype=s.dtype,
|
||||
)
|
||||
|
||||
for e in range(num_experts):
|
||||
output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
|
||||
return output
|
||||
|
||||
|
||||
def marlin_zero_points(
|
||||
zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
) -> torch.Tensor:
|
||||
# Permute zero-points in a similar way to scales, but do not use the
|
||||
# "single" permutation, since zero-points are applied on every MMA
|
||||
scale_perm, _ = get_scale_perms()
|
||||
zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
|
||||
# Interleave column dim (for the dequantize code) and pack it to int32
|
||||
if num_bits == 4:
|
||||
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = numpy.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
zp = zp.reshape((-1, size_n)).contiguous()
|
||||
zp = pack_cols(zp, num_bits, size_k, size_n)
|
||||
|
||||
return zp
|
||||
|
||||
|
||||
def awq_to_marlin_zero_points(
|
||||
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
) -> torch.Tensor:
|
||||
# AWQ zero-points are quantized and packed on the column dim.
|
||||
# In addition, the values are permuted based on dequantizer.
|
||||
# Here we undo both of these, and then apply marlin permutation
|
||||
# and pack it back.
|
||||
q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
|
||||
|
||||
# Undo interleaving (use argsort(..) to get inverse perm)
|
||||
if num_bits == 4:
|
||||
undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
|
||||
elif num_bits == 8:
|
||||
undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
|
||||
q_zp = q_zp.reshape((-1, size_n)).contiguous()
|
||||
|
||||
marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
|
||||
return marlin_zp
|
||||
|
||||
|
||||
def moe_awq_to_marlin_zero_points(
|
||||
q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
|
||||
):
|
||||
num_experts = q_zp_packed.shape[0]
|
||||
output = torch.empty(
|
||||
(num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
|
||||
device=q_zp_packed.device,
|
||||
dtype=q_zp_packed.dtype,
|
||||
)
|
||||
for e in range(num_experts):
|
||||
output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
|
||||
return output
|
||||
|
||||
|
||||
def maybe_warn_marlin_atomic_add(device, dtype):
|
||||
if torch.compiler.is_dynamo_compiling():
|
||||
return
|
||||
device_capability = torch.cuda.get_device_capability(device)
|
||||
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
||||
logger.info_once(
|
||||
"You are running Marlin kernel with bf16 on GPUs before SM90. "
|
||||
"You can consider change to fp16 to achieve better performance "
|
||||
"if possible."
|
||||
)
|
||||
|
||||
|
||||
def maybe_warn_marlin_atomic_add_env():
|
||||
if torch.compiler.is_dynamo_compiling():
|
||||
return
|
||||
# TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
if True:
|
||||
return
|
||||
# if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
|
||||
# return
|
||||
logger.info_once(
|
||||
"Marlin kernel can achieve better performance for small size_n "
|
||||
"with experimental use_atomic_add feature. "
|
||||
"You can consider set environment variable "
|
||||
"VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
|
||||
)
|
||||
|
||||
|
||||
def should_use_atomic_add_reduce(
|
||||
m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
|
||||
) -> bool:
|
||||
|
||||
# the performance of atomicAdd is better than global reduce
|
||||
# only when m*n is small and k is large
|
||||
if n >= 2048 or k < 2048 or device.type != "cuda":
|
||||
return False
|
||||
|
||||
# disable atomicAdd reduce by default,
|
||||
# one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
|
||||
# TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
|
||||
if not True:
|
||||
maybe_warn_marlin_atomic_add_env()
|
||||
return False
|
||||
|
||||
# sm8x doesn't support atomicAdd + bfloat16 natively
|
||||
device_capability = torch.cuda.get_device_capability(device)
|
||||
if device_capability[0] < 9 and dtype == torch.bfloat16:
|
||||
maybe_warn_marlin_atomic_add(device, dtype)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def apply_gptq_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_zp: torch.Tensor,
|
||||
g_idx: torch.Tensor,
|
||||
g_idx_sort_indices: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
wtype: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
is_k_full: bool,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
||||
|
||||
use_atomic_add = should_use_atomic_add_reduce(
|
||||
m=reshaped_x.size(0),
|
||||
n=output_size_per_partition,
|
||||
k=reshaped_x.size(1),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace,
|
||||
wtype,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
is_k_full=is_k_full,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
def apply_awq_marlin_linear(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
weight_zp: torch.Tensor,
|
||||
g_idx: torch.Tensor,
|
||||
g_idx_sort_indices: torch.Tensor,
|
||||
workspace: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
output_size_per_partition: int,
|
||||
input_size_per_partition: int,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
|
||||
) -> torch.Tensor:
|
||||
reshaped_x = input.reshape(-1, input.shape[-1])
|
||||
out_shape = input.shape[:-1] + (output_size_per_partition,)
|
||||
|
||||
use_atomic_add = should_use_atomic_add_reduce(
|
||||
m=reshaped_x.size(0),
|
||||
n=output_size_per_partition,
|
||||
k=reshaped_x.size(1),
|
||||
device=input.device,
|
||||
dtype=input.dtype,
|
||||
)
|
||||
|
||||
output = ops.gptq_marlin_gemm(
|
||||
reshaped_x,
|
||||
None,
|
||||
weight,
|
||||
weight_scale,
|
||||
None,
|
||||
weight_zp,
|
||||
g_idx,
|
||||
g_idx_sort_indices,
|
||||
workspace,
|
||||
quant_type,
|
||||
size_m=reshaped_x.shape[0],
|
||||
size_n=output_size_per_partition,
|
||||
size_k=input_size_per_partition,
|
||||
use_atomic_add=use_atomic_add,
|
||||
use_fp32_reduce=use_fp32_reduce,
|
||||
is_zp_float=False,
|
||||
)
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output.reshape(out_shape)
|
||||
|
||||
|
||||
class MarlinConfig(QuantizationConfig):
|
||||
"""Config class for Marlin.
|
||||
|
||||
Reference: https://github.com/IST-DASLab/marlin/tree/master
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
group_size: int,
|
||||
lm_head_quantized: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# Group size for the quantization.
|
||||
self.group_size = group_size
|
||||
self.lm_head_quantized = lm_head_quantized
|
||||
if self.group_size != 128 and self.group_size != -1:
|
||||
raise ValueError(
|
||||
"Currently, only group size 128 and -1 (channelwise) "
|
||||
"is supported for Marlin, but got group_size of "
|
||||
f"{self.group_size}"
|
||||
)
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // 4
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = 64
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = 128
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = 16
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"MarlinConfig(group_size={self.group_size}, "
|
||||
f"lm_head_quantized={self.lm_head_quantized})"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> str:
|
||||
return "marlin"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
|
||||
return cls(group_size, lm_head_quantized)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
|
||||
# compat: autogptq >=0.8.0 use checkpoint_format: str
|
||||
# compat: autogptq <=0.7.1 is_marlin_format: bool
|
||||
is_marlin_format = hf_quant_cfg.get(
|
||||
"checkpoint_format"
|
||||
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
|
||||
)
|
||||
|
||||
if is_marlin_format and is_valid_user_quant:
|
||||
msg = "The model is serialized in {} format. Using {} kernel.".format(
|
||||
cls.get_name(), cls.get_name()
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["MarlinLinearMethod"]:
|
||||
if isinstance(layer, LinearBase) or (
|
||||
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
|
||||
):
|
||||
return MarlinLinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class MarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: MarlinConfig):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}"
|
||||
)
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}."
|
||||
)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}."
|
||||
)
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}."
|
||||
)
|
||||
if (
|
||||
self.quant_config.group_size != -1
|
||||
and input_size_per_partition % self.quant_config.group_size != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}."
|
||||
)
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2
|
||||
)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError("Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size,
|
||||
output_size_per_partition
|
||||
* self.quant_config.tile_size
|
||||
// self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (
|
||||
1
|
||||
if self.quant_config.group_size == -1
|
||||
else input_size_per_partition // self.quant_config.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // self.quant_config.min_n_threads
|
||||
) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(
|
||||
data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("B", qweight)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
|
||||
layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
|
||||
layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.marlin_gemm(
|
||||
x_2d, qweight, scales, workspace, size_m, size_n, size_k
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -19,6 +19,36 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_weight_perm(num_bits: int):
|
||||
perm_list: List[int] = []
|
||||
for i in range(32):
|
||||
perm1: List[int] = []
|
||||
col = i // 4
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
2 * (i % 4),
|
||||
2 * (i % 4) + 1,
|
||||
2 * (i % 4 + 4),
|
||||
2 * (i % 4 + 4) + 1,
|
||||
]:
|
||||
perm1.append(16 * row + col + 8 * block)
|
||||
for j in range(4):
|
||||
perm_list.extend([p + 256 * j for p in perm1])
|
||||
|
||||
perm = np.array(perm_list)
|
||||
|
||||
if num_bits == 4:
|
||||
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = np.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
perm = torch.from_numpy(perm)
|
||||
return perm
|
||||
|
||||
|
||||
class MoeWNA16Config(QuantizationConfig):
|
||||
"""Config class for MOE WNA16 (W8A16/W4A16) quantization."""
|
||||
|
||||
|
||||
@@ -1,166 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from typing import Optional

import numpy
import torch
from sgl_kernel.scalar_type import ScalarType


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
python/sglang/srt/layers/quantization/scalar_type.py (new file, 352 lines)
@@ -0,0 +1,352 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import functools
import struct
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

_SCALAR_TYPES_ID_MAP = {}


# Mirrors enum in `core/scalar_type.hpp`
class NanRepr(Enum):
    NONE = 0  # nans are not supported
    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s


# This ScalarType class is a parallel implementation of the C++ ScalarType
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
# in sync until the inductor fully supports custom C++ classes.
@dataclass(frozen=True)
class ScalarType:
    """
    ScalarType can represent a wide range of floating point and integer
    types, in particular it can be used to represent sub-byte data types
    (something that torch.dtype currently does not support). It is also
    capable of representing types with a bias, i.e.
      `stored_value = value + bias`,
    which is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
    of 8). The implementation for this class can be found in
    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
    with that file.
    """

    exponent: int
    """
    Number of bits in the exponent if this is a floating point type
    (zero if this is an integer type).
    """

    mantissa: int
    """
    Number of bits in the mantissa if this is a floating point type,
    or the number of bits representing an integer excluding the sign bit if
    this is an integer type.
    """

    signed: bool
    "If the type is signed (i.e. has a sign bit)"

    bias: int
    """
    Bias used to encode the values in this scalar type
    (value = stored_value - bias, default 0). For example, if we store the
    type as an unsigned integer with a bias of 128, then the value 0 will be
    stored as 128, -1 will be stored as 127, and 1 will be stored as 129.
    """

    _finite_values_only: bool = False
    """
    Private: use `has_infs()` to check whether infs are supported.
    """

    nan_repr: NanRepr = NanRepr.IEEE_754
    """
    How NaNs are represented in this scalar type; returns a NanRepr value.
    (not applicable for integer types)
    """

    def _floating_point_max_int(self) -> int:
        assert (
            self.mantissa <= 52 and self.exponent <= 11
        ), f"Cannot represent max/min as a double for type {self.__str__()}"

        max_mantissa = (1 << self.mantissa) - 1
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
            max_mantissa = max_mantissa - 1

        max_exponent = (1 << self.exponent) - 2
        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
            assert (
                self.exponent < 11
            ), f"Cannot represent max/min as a double for type {self.__str__()}"
            max_exponent = max_exponent + 1

        # adjust the exponent to match that of a double
        # for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
        # e is the exponent bits), there is some precedent for non-standard
        # biases, example `float8_e4m3b11fnuz` here:
        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
        # complication we are just assuming the standard exponent bias until
        # there is a need to support non-standard biases
        exponent_bias = (1 << (self.exponent - 1)) - 1
        exponent_bias_double = (1 << 10) - 1  # double e = 11

        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double

        # shift the mantissa and exponent into the proper positions for an
        # IEEE double and bitwise-or them together.
        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)

    def _floating_point_max(self) -> float:
        double_raw = self._floating_point_max_int()
        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]

    def _raw_max(self) -> Union[int, float]:
        if self.is_floating_point():
            return self._floating_point_max()
        else:
            assert (
                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
            ), "Cannot represent max as an int"
            return (1 << self.mantissa) - 1

    def _raw_min(self) -> Union[int, float]:
        if self.is_floating_point():
            assert (
                self.is_signed()
            ), "We currently assume all floating point types are signed"
            sign_bit_double = 1 << 63

            max_raw = self._floating_point_max_int()
            min_raw = max_raw | sign_bit_double
            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
        else:
            assert (
                not self.is_signed() or self.size_bits <= 64
            ), "Cannot represent min as a int64_t"

            if self.is_signed():
                return -(1 << (self.size_bits - 1))
            else:
                return 0

    @functools.cached_property
    def id(self) -> int:
        """
        Convert the ScalarType to an int which can be passed to pytorch custom
        ops. This layout of the int must be kept in sync with the C++
        ScalarType's from_id method.
        """
        val = 0
        offset = 0

        def or_and_advance(member, bit_width):
            nonlocal val
            nonlocal offset
            bit_mask = (1 << bit_width) - 1
            val = val | (int(member) & bit_mask) << offset
            offset = offset + bit_width

        or_and_advance(self.exponent, 8)
        or_and_advance(self.mantissa, 8)
        or_and_advance(self.signed, 1)
        or_and_advance(self.bias, 32)
        or_and_advance(self._finite_values_only, 1)
        or_and_advance(self.nan_repr.value, 8)

        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"

        _SCALAR_TYPES_ID_MAP[val] = self

        return val

    @property
    def size_bits(self) -> int:
        return self.exponent + self.mantissa + int(self.signed)

    def min(self) -> Union[int, float]:
        """
        Min representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_min() - self.bias

    def max(self) -> Union[int, float]:
        """
        Max representable value for this scalar type.
        (accounting for bias if there is one)
        """
        return self._raw_max() - self.bias

    def is_signed(self) -> bool:
        """
        If the type is signed (i.e. has a sign bit), same as `signed`;
        added for consistency with:
        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
        """
        return self.signed

    def is_floating_point(self) -> bool:
        "If the type is a floating point type"
        return self.exponent != 0

    def is_integer(self) -> bool:
        "If the type is an integer type"
        return self.exponent == 0

    def has_bias(self) -> bool:
        "If the type has a non-zero bias"
        return self.bias != 0

    def has_infs(self) -> bool:
        "If the type is floating point and supports infinity"
        return not self._finite_values_only

    def has_nans(self) -> bool:
        return self.nan_repr != NanRepr.NONE.value

    def is_ieee_754(self) -> bool:
        """
        If the type is a floating point type that follows IEEE 754
        conventions
        """
        return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only

    def __str__(self) -> str:
        """
        naming generally follows: https://github.com/jax-ml/ml_dtypes
        for floating point types (leading f) the scheme is:
          `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
          flags:
          - no-flags: means it follows IEEE 754 conventions
          - f: means finite values only (no infinities)
          - n: means nans are supported (non-standard encoding)
        for integer types the scheme is:
          `[u]int<size_bits>[b<bias>]`
          - if bias is not present it means it is zero
        """
        if self.is_floating_point():
            ret = (
                "float"
                + str(self.size_bits)
                + "_e"
                + str(self.exponent)
                + "m"
                + str(self.mantissa)
            )

            if not self.is_ieee_754():
                if self._finite_values_only:
                    ret = ret + "f"
                if self.nan_repr != NanRepr.NONE:
                    ret = ret + "n"

            return ret
        else:
            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
            if self.has_bias():
                ret = ret + "b" + str(self.bias)
            return ret

    def __repr__(self) -> str:
        return "ScalarType." + self.__str__()

    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
    # opcheck to work.
    def __len__(self) -> int:
        raise TypeError

    #
    # Convenience Constructors
    #

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        "Create a signed integer scalar type (size_bits includes sign-bit)."
        ret = cls(0, size_bits - 1, True, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
        """Create an unsigned integer scalar type."""
        ret = cls(0, size_bits, False, bias if bias else 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
        """
        Create a standard floating point type
        (i.e. follows IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        ret = cls(exponent, mantissa, True, 0)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def float_(
        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
    ) -> "ScalarType":
        """
        Create a non-standard floating point type
        (i.e. does not follow IEEE 754 conventions).
        """
        assert mantissa > 0 and exponent > 0
        assert nan_repr != NanRepr.IEEE_754, (
            "use `float_IEEE754` constructor for floating point types that "
            "follow IEEE 754 conventions"
        )
        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
        ret.id  # noqa B018: make sure the id is cached
        return ret

    @classmethod
    def from_id(cls, scalar_type_id: int):
        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exist.")
        return _SCALAR_TYPES_ID_MAP[scalar_type_id]


# naming generally follows: https://github.com/jax-ml/ml_dtypes
# for floating point types (leading f) the scheme is:
#   `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
#   flags:
#   - no-flags: means it follows IEEE 754 conventions
#   - f: means finite values only (no infinities)
#   - n: means nans are supported (non-standard encoding)
# for integer types the scheme is:
#   `[u]int<size_bits>[b<bias>]`
#   - if bias is not present it means it is zero


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)

    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)

    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)

    # "gptq" types
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)

    # colloquial names
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
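As a usage sketch (importing from the new python/sglang/srt/layers/quantization/scalar_type.py added above), the GPTQ-style uint4b8 type shows how size, bias, and the id round trip behave:

from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types

t = scalar_types.uint4b8                 # 4-bit unsigned storage with a bias of 8
assert t.size_bits == 4
assert (t.min(), t.max()) == (-8, 7)     # representable values after removing the bias
assert str(t) == "uint4b8"

# ids are cached at construction time, so from_id() recovers the same type
assert ScalarType.from_id(t.id) == t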

@@ -1,11 +1,13 @@
# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

from types import MappingProxyType
from typing import List, Mapping, Tuple, Union
from typing import List, Mapping, Optional, Tuple, Union

import numpy
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.layers.quantization.scalar_type import ScalarType
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu

_is_cuda = is_cuda()

@@ -143,3 +145,162 @@ def replace_parameter(
    if not isinstance(new, torch.nn.Parameter):
        new = torch.nn.Parameter(new, requires_grad=False)
    mod.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


def get_pack_factor(num_bits):
    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
    return 32 // num_bits


def pack_cols(
    q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    assert q_w.shape == (size_k, size_n)

    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0

    orig_device = q_w.device

    q_w = q_w.cpu().numpy().astype(numpy.uint32)

    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)

    for i in range(pack_factor):
        q_res |= q_w[:, i::pack_factor] << num_bits * i

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res


def unpack_cols(
    packed_q_w: torch.Tensor,
    num_bits: int,
    size_k: int,
    size_n: int,
):
    pack_factor = get_pack_factor(num_bits)
    assert size_n % pack_factor == 0
    assert packed_q_w.shape == (
        size_k,
        size_n // pack_factor,
    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
        packed_q_w.shape, size_k, size_n, pack_factor
    )

    orig_device = packed_q_w.device

    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)

    mask = (1 << num_bits) - 1
    for i in range(pack_factor):
        vals = packed_q_w_cpu & mask
        packed_q_w_cpu >>= num_bits
        q_res[:, i::pack_factor] = vals

    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
    q_res = q_res.contiguous()

    return q_res
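The two helpers are inverses of each other as long as every value fits in num_bits. A small round-trip check, assuming the functions are importable from sglang.srt.layers.quantization.utils (the module this hunk extends), looks like:

import torch
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols

size_k, size_n, num_bits = 4, 16, 4
q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

packed = pack_cols(q_w, num_bits, size_k, size_n)
assert packed.shape == (size_k, size_n // (32 // num_bits))   # [4, 2]

restored = unpack_cols(packed, num_bits, size_k, size_n)
assert torch.equal(restored, q_w)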


# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
def quantize_weights(
    w: torch.Tensor,
    quant_type: ScalarType,
    group_size: Optional[int],
    zero_points: bool = False,
    ref_zero_points_after_scales: bool = False,
):
    assert (
        quant_type.is_integer()
    ), "Floating point quantization may work but has not been tested"
    assert not zero_points or group_size is not None, (
        "to have group zero points, group_size must be provided "
        "(-1 group_size is channelwise)"
    )

    orig_device = w.device
    orig_type = w.dtype
    size_k, size_n = w.shape

    assert w.is_floating_point(), "w must be float"

    if group_size == -1:
        group_size = size_k

    # Reshape to [groupsize, -1]
    if group_size is not None and group_size < size_k:
        w = w.reshape((-1, group_size, size_n))
        w = w.permute(1, 0, 2)
        w = w.reshape((group_size, -1))

    # Compute scale for each group
    max_val = torch.max(w, 0, keepdim=True).values
    min_val = torch.min(w, 0, keepdim=True).values

    max_q_val = quant_type.max()
    min_q_val = quant_type.min()

    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
    maybe_w_zp = None
    if group_size is not None:
        if zero_points:
            assert not quant_type.is_signed() and quant_type.max() > 0
            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
            maybe_w_zp = (
                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
            )
        else:
            # If the bias is such that there are no possible negative/positive
            # values, set the max value to inf to avoid divide by 0
            w_s = torch.max(
                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
            )

    # Quantize
    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
    w_q = torch.clamp(w_q, min_q_val, max_q_val)

    # Compute ref (dequantized)
    # For some kernels (namely Machete) the zero-points are applied after the
    # scales are applied, for this case computing the reference in similar way
    # allows us to use tighter error tolerances in our unit tests.
    if ref_zero_points_after_scales and maybe_w_zp is not None:
        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
    else:
        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s

    if quant_type.has_bias():
        w_q += quant_type.bias

    # Restore original shapes
    if group_size is not None and group_size < size_k:

        def reshape_w(w):
            w = w.reshape((group_size, -1, size_n))
            w = w.permute(1, 0, 2)
            w = w.reshape((size_k, size_n)).contiguous()
            return w

        w_q = reshape_w(w_q)
        w_ref = reshape_w(w_ref)
        w_s = w_s.reshape((-1, size_n)).contiguous()

    if maybe_w_zp is not None:
        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
        maybe_w_zp = maybe_w_zp.to(device=orig_device)

    return (
        w_ref.to(device=orig_device),
        w_q.to(device=orig_device),
        w_s if group_size is not None else None,
        maybe_w_zp,
    )
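A short end-to-end sketch of quantize_weights, again assuming the sglang.srt.layers.quantization.utils and scalar_type module paths from this commit: group-wise symmetric quantization of a weight matrix to the GPTQ-style uint4b8 type returns the dequantized reference, the biased integer weights, and one scale row per group.

import torch
from sglang.srt.layers.quantization.scalar_type import scalar_types
from sglang.srt.layers.quantization.utils import quantize_weights

size_k, size_n, group_size = 256, 64, 128
w = torch.randn(size_k, size_n, dtype=torch.float32)

w_ref, w_q, w_s, w_zp = quantize_weights(w, scalar_types.uint4b8, group_size)

assert w_q.shape == (size_k, size_n)
assert int(w_q.min()) >= 0 and int(w_q.max()) <= 15   # 4-bit storage after the +8 bias
assert w_s.shape == (size_k // group_size, size_n)    # one scale per group and column
assert w_zp is None                                   # no zero points were requested
# w_ref holds the dequantized weights; per element it differs from w by at most
# half a scale step, which is what Marlin/GPTQ test code compares against.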