[6/n] decouple quantization implementation from vLLM dependency (#10750)
@@ -10,10 +10,6 @@ import torch
try:
    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
        CompressedTensorsW8A8Fp8MoEMethod,
        CompressedTensorsWNA16MoEMethod,
    )
    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
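The hunk above trims the guarded vllm quantization-config imports. For context, here is a minimal sketch of the optional-dependency guard that the rest of this change keeps relying on; the try/except and the VLLM_AVAILABLE flag mirror code that appears later in this diff, and nothing beyond that is implied:

try:
    import vllm  # noqa: F401

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False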
@@ -175,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
        return original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance


def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
    """
    Monkey patch the apply function of vllm's FusedMoEMethodBase.
    Convert sglang arguments to vllm arguments.
    """
    original_apply = class_obj.apply
    sig = inspect.signature(original_apply)
    param_names = list(sig.parameters.keys())
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        topk_output: TopKOutput,
        *,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
    ):
        assert activation == "silu"
        assert inplace and not no_combine

        kwargs = {
            "self": self,
            "layer": layer,
            "x": x,
            "topk_output": topk_output,
        }
        return original_apply(**kwargs)

    setattr(class_obj, "apply", new_apply)


def monkey_patch_quant_configs():
    """Apply all monkey patches in one place."""

    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


# Only apply monkey patches if vllm is available
if VLLM_AVAILABLE:
    monkey_patch_quant_configs()
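monkey_patch_moe_apply wraps vllm's apply method so it can be called with sglang-style arguments. A self-contained sketch of the same wrapping technique on a toy class; all names below are illustrative and not taken from the commit:

import inspect


class _DemoMethod:
    def apply(self, layer, x, topk_output, e_score_correction_bias=None):
        return x


def _patch_apply(cls):
    original_apply = cls.apply
    param_names = list(inspect.signature(original_apply).parameters.keys())
    # mirrors the capability check in the hunk above
    has_correction_bias = "e_score_correction_bias" in param_names

    def new_apply(self, layer, x, topk_output, **_ignored):
        # Forward only the positional arguments the original signature expects.
        return original_apply(self, layer, x, topk_output)

    cls.apply = new_apply


_patch_apply(_DemoMethod)
assert _DemoMethod().apply(None, "x", None, activation="silu") == "x"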
@@ -30,10 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
    CompressedTensorsMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS,
    CompressedTensorsScheme,
    CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
    find_matched_target,
@@ -43,23 +45,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod

try:
    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
        CompressedTensors24,
    )
    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
        W4A16SPARSE24_SUPPORTED_BITS,
        CompressedTensorsW4A16Sparse24,
    )
    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
        WNA16_SUPPORTED_BITS,
        CompressedTensorsWNA16,
    )

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]
@@ -380,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                )
            if (
                self.quant_format == CompressionFormat.marlin_24.value
                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
            ):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size,
                )
            if (
                self.quant_format == CompressionFormat.pack_quantized.value
                and weight_quant.num_bits in WNA16_SUPPORTED_BITS
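The branch above dispatches on the checkpoint's compression format and bit width. A stripped-down sketch of the same dispatch; the format strings and the sparse-2:4 bit list are assumptions standing in for compressed-tensors' CompressionFormat values and the W4A16SPARSE24_SUPPORTED_BITS constant, while the WNA16 bit list matches the new scheme added later in this diff, and the helper itself is hypothetical:

W4A16SPARSE24_SUPPORTED_BITS = [4]  # assumed for illustration
WNA16_SUPPORTED_BITS = [4, 8]       # matches the new wNa16 scheme in this diff


def pick_wna16_scheme(quant_format: str, num_bits: int) -> str:
    # "marlin-24" / "pack-quantized" stand in for CompressionFormat.marlin_24.value
    # and CompressionFormat.pack_quantized.value.
    if quant_format == "marlin-24" and num_bits in W4A16SPARSE24_SUPPORTED_BITS:
        return "CompressedTensorsW4A16Sparse24"
    if quant_format == "pack-quantized" and num_bits in WNA16_SUPPORTED_BITS:
        return "CompressedTensorsWNA16"
    raise ValueError(f"unsupported wNa16 config: {quant_format!r}, {num_bits}")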
@@ -403,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder,
                )
            else:
                raise ImportError(
                    "Other method (CompressedTensorsW4A16Sparse24) is not supported now"
                )

        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
@@ -426,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):

        # note: input_quant can be None
        if self._is_fp8_w8a16(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                )
            is_static_input_scheme = input_quant and not input_quant.dynamic
            return CompressedTensorsW8A16Fp8(
                strategy=weight_quant.strategy,
@@ -470,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@robertgshaw): add compressed-tensors as dep
        # TODO : add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this
@@ -508,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
            input_quant=input_quant,
            sparsity_scheme=sparsity_scheme,
        ):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                )
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (
                None
                if sparsity_scheme is None or sparsity_scheme.format == "dense"
                else self.config
            )

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
            raise ImportError("CompressedTensors24 is not supported now")
        elif weight_quant is None:
            logger.warning_once(
                "Acceleration for non-quantized schemes is "
@@ -6,7 +6,7 @@ import enum
import logging
import re
from enum import Enum
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING

try:
    from sgl_kernel import fused_marlin_moe
@@ -31,9 +31,13 @@ from sglang.srt.environ import envs
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    WNA16_SUPPORTED_BITS,
)
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
from sglang.srt.layers.quantization.utils import (
    all_close_1d,
    per_tensor_dequantize,
@@ -42,6 +46,7 @@ from sglang.srt.layers.quantization.utils import (
from sglang.srt.utils import (
    get_bool_env_var,
    get_compiler_backend,
    is_cuda,
    is_hip,
    set_weight_attrs,
)
@@ -57,6 +62,8 @@ if TYPE_CHECKING:
    )

_is_hip = is_hip()
_is_cuda = is_cuda()

_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
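get_bool_env_var gates the AITER path on ROCm. A rough standard-library equivalent for readers who don't have the sglang helper handy; the truthy-string set and the is_hip stand-in are assumptions, not the real implementation:

import os

_is_hip_stub = False  # stand-in for is_hip()
_use_aiter_stub = os.getenv("SGLANG_USE_AITER", "false").lower() in ("1", "true") and _is_hip_stub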
@@ -64,12 +71,9 @@ if _use_aiter:

    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1

try:
    import vllm  # noqa: F401

    VLLM_AVAILABLE = True
except ImportError:
    VLLM_AVAILABLE = False
if _is_cuda:
    from sgl_kernel import fused_marlin_moe

logger = logging.getLogger(__name__)
@@ -127,10 +131,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                )

            logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
            return CompressedTensorsWNA16MoEMethod(quant_config)
        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
@@ -432,9 +434,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
    ):
        if self.num_gpu_experts != -1:
            num_experts = self.num_gpu_experts
        # assert (
        # params_dtype == torch.float16
        # ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501

        # Will transpose the loaded weight along the
        # intermediate and hidden dim sizes. Will
@@ -573,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            getattr(layer, name).copy_(new_t)
            del new_t

        def get_scale_perms(num_bits: int):
            scale_perm: List[int] = []
            for i in range(8):
                scale_perm.extend([i + 8 * j for j in range(8)])
            scale_perm_single: List[int] = []
            for i in range(4):
                scale_perm_single.extend(
                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
                )
            return scale_perm, scale_perm_single

        def marlin_permute_scales(
            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
        ):
            scale_perm, scale_perm_single = get_scale_perms(num_bits)
            if group_size < size_k and group_size != -1:
                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
            else:
                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
            s = s.reshape((-1, size_n)).contiguous()
            return s

        def marlin_moe_permute_scales(
            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
        ):
            num_experts = s.shape[0]
            output = torch.empty(
                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
            )
            for e in range(num_experts):
                output[e] = marlin_permute_scales(
                    s[e], size_k, size_n, group_size, num_bits
                )
            return output

        size_k2 = layer.w2_weight_packed.shape[2]
        size_k13 = layer.w13_weight_packed.shape[2]

        num_experts = layer.w13_weight_g_idx.shape[0]
        device = layer.w13_weight_g_idx.device
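The helpers in this hunk, apparently dropped here in favour of the marlin_utils import added earlier in this diff, build fixed permutations for Marlin's scale layout. A quick worked check of the sizes they produce (runnable on its own):

scale_perm = [i + 8 * j for i in range(8) for j in range(8)]
scale_perm_single = [2 * i + j for i in range(4) for j in [0, 1, 8, 9, 16, 17, 24, 25]]
assert len(scale_perm) == 64 and len(scale_perm_single) == 32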
@@ -657,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            requires_grad=False,
        )

        from vllm import _custom_ops as vllm_ops

        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
            layer.w13_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
        replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
            layer.w2_weight_packed.shape[2],
            self.num_bits,
        )
        replace_tensor("w2_weight_packed", marlin_w2_qweight)
        replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
        # Repack scales
        marlin_w13_scales = marlin_moe_permute_scales(
            layer.w13_weight_scale,
            size_k13,
            layer.w13_weight_packed.shape[2],
            layer.w13_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w13_weight_scale", marlin_w13_scales)
        replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)

        marlin_w2_scales = marlin_moe_permute_scales(
            layer.w2_weight_scale,
            layer.w2_weight_scale.shape[1]
            * (self.group_size if self.group_size != -1 else self.packed_factor),
            size_k2,
            layer.w2_weight_scale.shape[2],
            self.group_size,
            self.num_bits,
        )
        replace_tensor("w2_weight_scale", marlin_w2_scales)
        replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
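In the hunk above, replace_tensor is swapped for replace_parameter. A sketch of what a replace_parameter-style helper does, using only torch; the real helper is imported from sglang.srt.layers.quantization.utils, so this reimplementation is an assumption for illustration only:

import torch


def replace_parameter_sketch(layer: torch.nn.Module, name: str, new: torch.Tensor) -> None:
    # Re-register the repacked tensor as a frozen Parameter so later code can
    # keep using `getattr(layer, name)` and state_dict() as before.
    if name in layer._parameters:
        del layer._parameters[name]
    layer.register_parameter(name, torch.nn.Parameter(new, requires_grad=False))


_lin = torch.nn.Linear(4, 4, bias=False)
replace_parameter_sketch(_lin, "weight", _lin.weight.data.t().contiguous())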
@@ -716,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):

        topk_weights, topk_ids, router_logits = topk_output

        output = torch.ops.vllm.fused_marlin_moe(
        output = fused_marlin_moe(
            x,
            layer.w13_weight_packed,
            layer.w2_weight_packed,
@@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16

__all__ = [
    "CompressedTensorsScheme",
    "CompressedTensorsW8A8Fp8",
    "CompressedTensorsW8A16Fp8",
    "CompressedTensorsW8A8Int8",
    "CompressedTensorsWNA16",
    "WNA16_SUPPORTED_BITS",
]
@@ -14,25 +14,12 @@ from sglang.srt.layers.parameter import (
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
    apply_fp8_marlin_linear,
    prepare_fp8_layer_for_marlin,
)
from sglang.srt.layers.quantization.utils import convert_to_channelwise

try:
    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
        apply_fp8_marlin_linear,
        prepare_fp8_layer_for_marlin,
    )

    MARLIN_FP8_AVAILABLE = True
except ImportError:
    MARLIN_FP8_AVAILABLE = False

    def apply_fp8_marlin_linear(*args, **kwargs):
        raise ImportError("vllm is not installed")

    def prepare_fp8_layer_for_marlin(*args, **kwargs):
        raise ImportError("vllm is not installed")


__all__ = ["CompressedTensorsW8A16Fp8"]

SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
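The old code appears to stub out the marlin FP8 entry points when vllm is missing, so the names still exist but fail loudly on first use, while the new code imports sglang's own marlin_utils_fp8 implementations directly. A minimal sketch of that fallback-stub pattern; the module and function names below are hypothetical:

try:
    from some_optional_backend import fast_linear  # hypothetical optional dependency
except ImportError:
    BACKEND_AVAILABLE = False

    def fast_linear(*args, **kwargs):
        raise ImportError("optional backend is not installed")
else:
    BACKEND_AVAILABLE = True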
@@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
        self.strategy = strategy
        self.is_static_input_scheme = is_static_input_scheme

        if not MARLIN_FP8_AVAILABLE:
            raise ImportError(
                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
            )

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
@@ -0,0 +1,339 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import Callable, Optional

import torch
from compressed_tensors.quantization import ActivationOrdering

# yapf conflicts with isort for this block
# yapf: disable
from sglang.srt.layers.parameter import (
    BasevLLMParameter,
    ChannelQuantScaleParameter,
    GroupQuantScaleParameter,
    PackedColumnParameter,
    PackedvLLMParameter,
    RowvLLMParameter,
    permute_param_layout_,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
)
from sglang.srt.layers.quantization.marlin_utils import (
    MarlinLinearLayerConfig,
    apply_gptq_marlin_linear,
    check_marlin_supports_shape,
    marlin_is_k_full,
    marlin_make_empty_g_idx,
    marlin_make_workspace,
    marlin_permute_scales,
    marlin_repeat_scales_on_all_ranks,
    marlin_sort_g_idx,
    marlin_zero_points,
)
from sglang.srt.layers.quantization.utils import (
    get_scalar_types,
    replace_parameter,
    unpack_cols,
)
from sglang.srt.utils import is_cuda

_is_cuda = is_cuda()

if _is_cuda:
    from sgl_kernel import gptq_marlin_repack


ScalarType, scalar_types = get_scalar_types()

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsWNA16"]
WNA16_SUPPORTED_TYPES_MAP = {
    4: scalar_types.uint4b8,
    8: scalar_types.uint8b128
}
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())


class CompressedTensorsWNA16(CompressedTensorsScheme):
    _kernel_backends_being_used: set[str] = set()

    def __init__(self,
                 strategy: str,
                 num_bits: int,
                 group_size: Optional[int] = None,
                 symmetric: Optional[bool] = True,
                 actorder: Optional[ActivationOrdering] = None):

        self.pack_factor = 32 // num_bits
        self.strategy = strategy
        self.symmetric = symmetric
        self.group_size = -1 if group_size is None else group_size
        self.has_g_idx = actorder == ActivationOrdering.GROUP

        if self.group_size == -1 and self.strategy != "channel":
            raise ValueError("Marlin kernels require group quantization or "
                             "channelwise quantization, but found no group "
                             "size and strategy is not channelwise.")

        if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
            raise ValueError(
                f"Unsupported num_bits = {num_bits}. "
                f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")

        self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
                           if not self.symmetric else
                           WNA16_SUPPORTED_TYPES_MAP[num_bits])

    @classmethod
    def get_min_capability(cls) -> int:
        # ampere and up
        return 80
    def create_weights(self, layer: torch.nn.Module, output_size: int,
                       input_size: int, output_partition_sizes: list[int],
                       input_size_per_partition: int,
                       params_dtype: torch.dtype, weight_loader: Callable,
                       **kwargs):

        output_size_per_partition = sum(output_partition_sizes)

        self.kernel_config = MarlinLinearLayerConfig(
            full_weight_shape=(input_size, output_size),
            partition_weight_shape=(
                input_size_per_partition,
                output_size_per_partition,
            ),
            weight_type=self.quant_type,
            act_type=params_dtype,
            group_size=self.group_size,
            zero_points=not self.symmetric,
            has_g_idx=self.has_g_idx
        )

        # If group_size is -1, we are in channelwise case.
        group_size = self.group_size if self.group_size != -1 else input_size
        row_parallel = (input_size != input_size_per_partition)
        partition_scales = not marlin_repeat_scales_on_all_ranks(
            self.has_g_idx, self.group_size, row_parallel)

        scales_and_zp_size = input_size // group_size

        if partition_scales:
            assert input_size_per_partition % group_size == 0
            scales_and_zp_size = input_size_per_partition // group_size

        weight = PackedvLLMParameter(input_dim=1,
                                     output_dim=0,
                                     weight_loader=weight_loader,
                                     packed_factor=self.pack_factor,
                                     packed_dim=1,
                                     data=torch.empty(
                                         output_size_per_partition,
                                         input_size_per_partition //
                                         self.pack_factor,
                                         dtype=torch.int32,
                                     ))

        weight_scale_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.empty(
                output_size_per_partition,
                scales_and_zp_size,
                dtype=params_dtype,
            )
        }

        zeros_args = {
            "weight_loader":
            weight_loader,
            "data":
            torch.zeros(
                output_size_per_partition // self.pack_factor,
                scales_and_zp_size,
                dtype=torch.int32,
            )
        }

        if not partition_scales:
            weight_scale = ChannelQuantScaleParameter(output_dim=0,
                                                      **weight_scale_args)

            if not self.symmetric:
                qzeros = PackedColumnParameter(output_dim=0,
                                               packed_dim=0,
                                               packed_factor=self.pack_factor,
                                               **zeros_args)
        else:
            weight_scale = GroupQuantScaleParameter(output_dim=0,
                                                    input_dim=1,
                                                    **weight_scale_args)
            if not self.symmetric:
                qzeros = PackedvLLMParameter(input_dim=1,
                                             output_dim=0,
                                             packed_dim=0,
                                             packed_factor=self.pack_factor,
                                             **zeros_args)

        # A 2D array defining the original shape of the weights
        # before packing
        weight_shape = BasevLLMParameter(data=torch.empty(2,
                                                          dtype=torch.int64),
                                         weight_loader=weight_loader)

        layer.register_parameter("weight_packed", weight)
        layer.register_parameter("weight_scale", weight_scale)
        layer.register_parameter("weight_shape", weight_shape)

        if not self.symmetric:
            layer.register_parameter("weight_zero_point", qzeros)

        # group index (for activation reordering)
        if self.has_g_idx:
            weight_g_idx = RowvLLMParameter(data=torch.empty(
                input_size_per_partition,
                dtype=torch.int32,
            ),
                                            input_dim=0,
                                            weight_loader=weight_loader)
            layer.register_parameter("weight_g_idx", weight_g_idx)
    # Checkpoints are serialized in compressed-tensors format, which is
    # different from the format the kernel may want. Handle repacking here.
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Default names since marlin requires empty parameters for these,
        # TODO: remove this requirement from marlin (allow optional tensors)
        self.w_q_name = "weight_packed"
        self.w_s_name = "weight_scale"
        self.w_zp_name = "weight_zero_point"
        self.w_gidx_name = "weight_g_idx"

        device = getattr(layer, self.w_q_name).device
        c = self.kernel_config

        check_marlin_supports_shape(
            c.partition_weight_shape[1],  # out_features
            c.partition_weight_shape[0],  # in_features
            c.full_weight_shape[0],  # in_features
            c.group_size,
        )

        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)

        # Allocate marlin workspace.
        self.workspace = marlin_make_workspace(device)

        def _transform_param(
            layer: torch.nn.Module, name: Optional[str], fn: Callable
        ) -> None:
            if name is not None and getattr(layer, name, None) is not None:

                old_param = getattr(layer, name)
                new_param = fn(old_param)
                # replace the parameter with torch.nn.Parameter for TorchDynamo
                # compatibility
                replace_parameter(
                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
                )

        def transform_w_q(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
            x.data = gptq_marlin_repack(
                x.data.contiguous(),
                perm=layer.g_idx_sort_indices,
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                num_bits=c.weight_type.size_bits,
            )
            return x

        def transform_w_s(x):
            assert isinstance(x, BasevLLMParameter)
            permute_param_layout_(x, input_dim=0, output_dim=1)
            x.data = marlin_permute_scales(
                x.data.contiguous(),
                size_k=c.partition_weight_shape[0],
                size_n=c.partition_weight_shape[1],
                group_size=c.group_size,
            )
            return x

        if c.has_g_idx:
            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
                getattr(layer, self.w_gidx_name)
            )
            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
            layer.g_idx_sort_indices = g_idx_sort_indices
        else:
            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

        if c.zero_points:
            grouped_k = (
                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
            )
            _transform_param(
                layer,
                self.w_zp_name,
                lambda x: marlin_zero_points(
                    unpack_cols(
                        x.t(),
                        c.weight_type.size_bits,
                        grouped_k,
                        c.partition_weight_shape[1],
                    ),
                    size_k=grouped_k,
                    size_n=c.partition_weight_shape[1],
                    num_bits=c.weight_type.size_bits,
                ),
            )
        else:
            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
        _transform_param(layer, self.w_q_name, transform_w_q)
        _transform_param(layer, self.w_s_name, transform_w_s)
    def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        c = self.kernel_config

        def _get_weight_params(
            layer: torch.nn.Module,
        ) -> tuple[
            torch.Tensor,  # w_q
            torch.Tensor,  # w_s
            Optional[torch.Tensor],  # w_zp,
            Optional[torch.Tensor],  # w_gidx
        ]:
            return (
                getattr(layer, self.w_q_name),
                getattr(layer, self.w_s_name),
                getattr(layer, self.w_zp_name or "", None),
                getattr(layer, self.w_gidx_name or "", None),
            )

        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)

        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
        # None for marlin
        return apply_gptq_marlin_linear(
            input=x,
            weight=w_q,
            weight_scale=w_s,
            weight_zp=w_zp,  # type: ignore
            g_idx=w_gidx,  # type: ignore
            g_idx_sort_indices=layer.g_idx_sort_indices,
            workspace=self.workspace,
            wtype=c.weight_type,
            input_size_per_partition=c.partition_weight_shape[0],
            output_size_per_partition=c.partition_weight_shape[1],
            is_k_full=self.is_k_full,
            bias=bias,
        )
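A quick worked example of the packing arithmetic used by CompressedTensorsWNA16 above (sizes are illustrative; only the 32 // num_bits relation comes from the new file):

num_bits = 4
pack_factor = 32 // num_bits            # 8 four-bit weights share one int32
input_size_per_partition = 4096
packed_cols = input_size_per_partition // pack_factor
assert (pack_factor, packed_cols) == (8, 512)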
@@ -4,6 +4,7 @@
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Optional

import numpy
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
USE_FP32_REDUCE_DEFAULT = True


@dataclass
class MarlinLinearLayerConfig:
    full_weight_shape: tuple[int, int]  # [in, out]
    partition_weight_shape: tuple[int, int]
    weight_type: ScalarType
    act_type: torch.dtype
    group_size: int
    zero_points: bool
    has_g_idx: bool


# For binary size and compile time, we don't support the same types for with and
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
# TODO: we may want to move this into the C++ so its closer to the actual impl
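An illustrative construction of the new dataclass, mirroring how create_weights fills it in compressed_tensors_wNa16.py earlier in this diff; every value below is made up, and torch / scalar_types are assumed to be the surrounding module's imports:

cfg = MarlinLinearLayerConfig(
    full_weight_shape=(4096, 11008),       # [in, out] before tensor-parallel split
    partition_weight_shape=(2048, 11008),  # this rank's shard
    weight_type=scalar_types.uint4b8,      # 4-bit symmetric weights
    act_type=torch.half,
    group_size=128,
    zero_points=False,
    has_g_idx=False,
)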