[6/n]decouple quantization implementation from vLLM dependency (#10750)
This commit is contained in:
@@ -10,10 +10,6 @@ import torch
|
|||||||
try:
|
try:
|
||||||
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
|
||||||
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
|
|
||||||
CompressedTensorsW8A8Fp8MoEMethod,
|
|
||||||
CompressedTensorsWNA16MoEMethod,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
|
||||||
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
|
||||||
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
from vllm.model_executor.layers.quantization.gguf import GGUFConfig
|
||||||
@@ -175,51 +171,3 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
|
|||||||
return original_isinstance(obj, classinfo)
|
return original_isinstance(obj, classinfo)
|
||||||
|
|
||||||
builtins.isinstance = patched_isinstance
|
builtins.isinstance = patched_isinstance
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
|
||||||
"""
|
|
||||||
Monkey patch the apply function of vllm's FusedMoEMethodBase.
|
|
||||||
Convert sglang arguments to vllm arguments.
|
|
||||||
"""
|
|
||||||
original_apply = class_obj.apply
|
|
||||||
sig = inspect.signature(original_apply)
|
|
||||||
param_names = list(sig.parameters.keys())
|
|
||||||
has_correction_bias = "e_score_correction_bias" in param_names
|
|
||||||
|
|
||||||
def new_apply(
|
|
||||||
self,
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
topk_output: TopKOutput,
|
|
||||||
*,
|
|
||||||
activation: str = "silu",
|
|
||||||
apply_router_weight_on_input: bool = False,
|
|
||||||
inplace: bool = True,
|
|
||||||
no_combine: bool = False,
|
|
||||||
routed_scaling_factor: Optional[float] = None,
|
|
||||||
):
|
|
||||||
assert activation == "silu"
|
|
||||||
assert inplace and not no_combine
|
|
||||||
|
|
||||||
kwargs = {
|
|
||||||
"self": self,
|
|
||||||
"layer": layer,
|
|
||||||
"x": x,
|
|
||||||
"topk_output": topk_output,
|
|
||||||
}
|
|
||||||
return original_apply(**kwargs)
|
|
||||||
|
|
||||||
setattr(class_obj, "apply", new_apply)
|
|
||||||
|
|
||||||
|
|
||||||
def monkey_patch_quant_configs():
|
|
||||||
"""Apply all monkey patches in one place."""
|
|
||||||
|
|
||||||
monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
|
|
||||||
monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
|
|
||||||
|
|
||||||
|
|
||||||
# Only apply monkey patches if vllm is available
|
|
||||||
if VLLM_AVAILABLE:
|
|
||||||
monkey_patch_quant_configs()
|
|
||||||
|
|||||||
@@ -30,10 +30,12 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
|
|||||||
CompressedTensorsMoEMethod,
|
CompressedTensorsMoEMethod,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
WNA16_SUPPORTED_BITS,
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
CompressedTensorsW8A8Fp8,
|
CompressedTensorsW8A8Fp8,
|
||||||
CompressedTensorsW8A8Int8,
|
CompressedTensorsW8A8Int8,
|
||||||
CompressedTensorsW8A16Fp8,
|
CompressedTensorsW8A16Fp8,
|
||||||
|
CompressedTensorsWNA16,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
||||||
find_matched_target,
|
find_matched_target,
|
||||||
@@ -43,23 +45,6 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
|
|||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_24 import (
|
|
||||||
CompressedTensors24,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_w4a16_sparse24 import (
|
|
||||||
W4A16SPARSE24_SUPPORTED_BITS,
|
|
||||||
CompressedTensorsW4A16Sparse24,
|
|
||||||
)
|
|
||||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
|
|
||||||
WNA16_SUPPORTED_BITS,
|
|
||||||
CompressedTensorsWNA16,
|
|
||||||
)
|
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
VLLM_AVAILABLE = False
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
__all__ = ["CompressedTensorsLinearMethod"]
|
__all__ = ["CompressedTensorsLinearMethod"]
|
||||||
@@ -380,19 +365,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
# Detect If Mixed Precision
|
# Detect If Mixed Precision
|
||||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
if not VLLM_AVAILABLE:
|
|
||||||
raise ImportError(
|
|
||||||
"vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
|
|
||||||
)
|
|
||||||
if (
|
|
||||||
self.quant_format == CompressionFormat.marlin_24.value
|
|
||||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
|
|
||||||
):
|
|
||||||
return CompressedTensorsW4A16Sparse24(
|
|
||||||
strategy=weight_quant.strategy,
|
|
||||||
num_bits=weight_quant.num_bits,
|
|
||||||
group_size=weight_quant.group_size,
|
|
||||||
)
|
|
||||||
if (
|
if (
|
||||||
self.quant_format == CompressionFormat.pack_quantized.value
|
self.quant_format == CompressionFormat.pack_quantized.value
|
||||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS
|
and weight_quant.num_bits in WNA16_SUPPORTED_BITS
|
||||||
@@ -403,6 +375,10 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
group_size=weight_quant.group_size,
|
group_size=weight_quant.group_size,
|
||||||
actorder=weight_quant.actorder,
|
actorder=weight_quant.actorder,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise ImportError(
|
||||||
|
"Other method (CompressedTensorsW4A16Sparse24) is not supported now"
|
||||||
|
)
|
||||||
|
|
||||||
if is_activation_quantization_format(self.quant_format):
|
if is_activation_quantization_format(self.quant_format):
|
||||||
if self._is_fp8_w8a8(weight_quant, input_quant):
|
if self._is_fp8_w8a8(weight_quant, input_quant):
|
||||||
@@ -426,10 +402,6 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
# note: input_quant can be None
|
# note: input_quant can be None
|
||||||
if self._is_fp8_w8a16(weight_quant, input_quant):
|
if self._is_fp8_w8a16(weight_quant, input_quant):
|
||||||
if not VLLM_AVAILABLE:
|
|
||||||
raise ImportError(
|
|
||||||
"vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
|
|
||||||
)
|
|
||||||
is_static_input_scheme = input_quant and not input_quant.dynamic
|
is_static_input_scheme = input_quant and not input_quant.dynamic
|
||||||
return CompressedTensorsW8A16Fp8(
|
return CompressedTensorsW8A16Fp8(
|
||||||
strategy=weight_quant.strategy,
|
strategy=weight_quant.strategy,
|
||||||
@@ -470,7 +442,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
|
|
||||||
# Find the "target" in the compressed-tensors config
|
# Find the "target" in the compressed-tensors config
|
||||||
# that our layer conforms to.
|
# that our layer conforms to.
|
||||||
# TODO (@robertgshaw): add compressed-tensors as dep
|
# TODO : add compressed-tensors as dep
|
||||||
# so we do not have to re-write these functions
|
# so we do not have to re-write these functions
|
||||||
# need to make accelerate optional in ct to do this
|
# need to make accelerate optional in ct to do this
|
||||||
|
|
||||||
@@ -508,24 +480,7 @@ class CompressedTensorsConfig(QuantizationConfig):
|
|||||||
input_quant=input_quant,
|
input_quant=input_quant,
|
||||||
sparsity_scheme=sparsity_scheme,
|
sparsity_scheme=sparsity_scheme,
|
||||||
):
|
):
|
||||||
if not VLLM_AVAILABLE:
|
raise ImportError("CompressedTensors24 is not supported now")
|
||||||
raise ImportError(
|
|
||||||
"vllm is not installed, to use CompressedTensors24, please install vllm"
|
|
||||||
)
|
|
||||||
# Have a valid sparsity scheme
|
|
||||||
# Validate layer is supported by Cutlass 2:4 Kernel
|
|
||||||
model_compression_config = (
|
|
||||||
None
|
|
||||||
if sparsity_scheme is None or sparsity_scheme.format == "dense"
|
|
||||||
else self.config
|
|
||||||
)
|
|
||||||
|
|
||||||
scheme = CompressedTensors24(
|
|
||||||
quantized=weight_quant is not None or input_quant is not None,
|
|
||||||
weight_quant=weight_quant,
|
|
||||||
input_quant=input_quant,
|
|
||||||
model_compression_config=model_compression_config,
|
|
||||||
)
|
|
||||||
elif weight_quant is None:
|
elif weight_quant is None:
|
||||||
logger.warning_once(
|
logger.warning_once(
|
||||||
"Acceleration for non-quantized schemes is "
|
"Acceleration for non-quantized schemes is "
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import enum
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sgl_kernel import fused_marlin_moe
|
from sgl_kernel import fused_marlin_moe
|
||||||
@@ -31,9 +31,13 @@ from sglang.srt.environ import envs
|
|||||||
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
from sglang.srt.layers.moe import MoeRunner, MoeRunnerBackend, MoeRunnerConfig
|
||||||
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
from sglang.srt.layers.moe.moe_runner.triton import TritonMoeQuantInfo
|
||||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||||
from sglang.srt.layers.quantization.compressed_tensors import WNA16_SUPPORTED_BITS
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
WNA16_SUPPORTED_BITS,
|
||||||
|
)
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
|
from sglang.srt.layers.quantization.gptq import gptq_marlin_moe_repack
|
||||||
|
from sglang.srt.layers.quantization.marlin_utils import marlin_moe_permute_scales
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
@@ -42,6 +46,7 @@ from sglang.srt.layers.quantization.utils import (
|
|||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_bool_env_var,
|
get_bool_env_var,
|
||||||
get_compiler_backend,
|
get_compiler_backend,
|
||||||
|
is_cuda,
|
||||||
is_hip,
|
is_hip,
|
||||||
set_weight_attrs,
|
set_weight_attrs,
|
||||||
)
|
)
|
||||||
@@ -57,6 +62,8 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
||||||
|
|
||||||
if _use_aiter:
|
if _use_aiter:
|
||||||
@@ -64,12 +71,9 @@ if _use_aiter:
|
|||||||
|
|
||||||
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
|
||||||
|
|
||||||
try:
|
|
||||||
import vllm # noqa: F401
|
|
||||||
|
|
||||||
VLLM_AVAILABLE = True
|
if _is_cuda:
|
||||||
except ImportError:
|
from sgl_kernel import fused_marlin_moe
|
||||||
VLLM_AVAILABLE = False
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -127,10 +131,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
|||||||
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
|
weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
|
||||||
input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
|
input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
|
||||||
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
|
||||||
if not VLLM_AVAILABLE:
|
|
||||||
raise ImportError(
|
logger.info_once("Using CompressedTensorsWNA16MarlinMoEMethod")
|
||||||
"vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
|
|
||||||
)
|
|
||||||
return CompressedTensorsWNA16MoEMethod(quant_config)
|
return CompressedTensorsWNA16MoEMethod(quant_config)
|
||||||
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
|
||||||
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
|
||||||
@@ -432,9 +434,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
):
|
):
|
||||||
if self.num_gpu_experts != -1:
|
if self.num_gpu_experts != -1:
|
||||||
num_experts = self.num_gpu_experts
|
num_experts = self.num_gpu_experts
|
||||||
# assert (
|
|
||||||
# params_dtype == torch.float16
|
|
||||||
# ), "float16 is required for MoE compressed models. Set dtype=torch.float16" # noqa: E501
|
|
||||||
|
|
||||||
# Will transpose the loaded weight along the
|
# Will transpose the loaded weight along the
|
||||||
# intermediate and hidden dim sizes. Will
|
# intermediate and hidden dim sizes. Will
|
||||||
@@ -573,44 +572,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
getattr(layer, name).copy_(new_t)
|
getattr(layer, name).copy_(new_t)
|
||||||
del new_t
|
del new_t
|
||||||
|
|
||||||
def get_scale_perms(num_bits: int):
|
|
||||||
scale_perm: List[int] = []
|
|
||||||
for i in range(8):
|
|
||||||
scale_perm.extend([i + 8 * j for j in range(8)])
|
|
||||||
scale_perm_single: List[int] = []
|
|
||||||
for i in range(4):
|
|
||||||
scale_perm_single.extend(
|
|
||||||
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
|
|
||||||
)
|
|
||||||
return scale_perm, scale_perm_single
|
|
||||||
|
|
||||||
def marlin_permute_scales(
|
|
||||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
|
|
||||||
):
|
|
||||||
scale_perm, scale_perm_single = get_scale_perms(num_bits)
|
|
||||||
if group_size < size_k and group_size != -1:
|
|
||||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
|
||||||
else:
|
|
||||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
|
||||||
s = s.reshape((-1, size_n)).contiguous()
|
|
||||||
return s
|
|
||||||
|
|
||||||
def marlin_moe_permute_scales(
|
|
||||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
|
|
||||||
):
|
|
||||||
num_experts = s.shape[0]
|
|
||||||
output = torch.empty(
|
|
||||||
(num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
|
|
||||||
)
|
|
||||||
for e in range(num_experts):
|
|
||||||
output[e] = marlin_permute_scales(
|
|
||||||
s[e], size_k, size_n, group_size, num_bits
|
|
||||||
)
|
|
||||||
return output
|
|
||||||
|
|
||||||
size_k2 = layer.w2_weight_packed.shape[2]
|
|
||||||
size_k13 = layer.w13_weight_packed.shape[2]
|
|
||||||
|
|
||||||
num_experts = layer.w13_weight_g_idx.shape[0]
|
num_experts = layer.w13_weight_g_idx.shape[0]
|
||||||
device = layer.w13_weight_g_idx.device
|
device = layer.w13_weight_g_idx.device
|
||||||
|
|
||||||
@@ -657,42 +618,39 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
requires_grad=False,
|
requires_grad=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
from vllm import _custom_ops as vllm_ops
|
marlin_w13_qweight = gptq_marlin_moe_repack(
|
||||||
|
|
||||||
marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
|
|
||||||
layer.w13_weight_packed,
|
layer.w13_weight_packed,
|
||||||
layer.w13_g_idx_sort_indices,
|
layer.w13_g_idx_sort_indices,
|
||||||
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
layer.w13_weight_packed.shape[1] * self.packed_factor,
|
||||||
layer.w13_weight_packed.shape[2],
|
layer.w13_weight_packed.shape[2],
|
||||||
self.num_bits,
|
self.num_bits,
|
||||||
)
|
)
|
||||||
replace_tensor("w13_weight_packed", marlin_w13_qweight)
|
replace_parameter(layer, "w13_weight_packed", marlin_w13_qweight)
|
||||||
marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
|
marlin_w2_qweight = gptq_marlin_moe_repack(
|
||||||
layer.w2_weight_packed,
|
layer.w2_weight_packed,
|
||||||
layer.w2_g_idx_sort_indices,
|
layer.w2_g_idx_sort_indices,
|
||||||
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
layer.w2_weight_packed.shape[1] * self.packed_factor,
|
||||||
layer.w2_weight_packed.shape[2],
|
layer.w2_weight_packed.shape[2],
|
||||||
self.num_bits,
|
self.num_bits,
|
||||||
)
|
)
|
||||||
replace_tensor("w2_weight_packed", marlin_w2_qweight)
|
replace_parameter(layer, "w2_weight_packed", marlin_w2_qweight)
|
||||||
# Repack scales
|
# Repack scales
|
||||||
marlin_w13_scales = marlin_moe_permute_scales(
|
marlin_w13_scales = marlin_moe_permute_scales(
|
||||||
layer.w13_weight_scale,
|
layer.w13_weight_scale,
|
||||||
size_k13,
|
layer.w13_weight_packed.shape[2],
|
||||||
layer.w13_weight_scale.shape[2],
|
layer.w13_weight_scale.shape[2],
|
||||||
self.group_size,
|
self.group_size,
|
||||||
self.num_bits,
|
|
||||||
)
|
)
|
||||||
replace_tensor("w13_weight_scale", marlin_w13_scales)
|
replace_parameter(layer, "w13_weight_scale", marlin_w13_scales)
|
||||||
|
|
||||||
marlin_w2_scales = marlin_moe_permute_scales(
|
marlin_w2_scales = marlin_moe_permute_scales(
|
||||||
layer.w2_weight_scale,
|
layer.w2_weight_scale,
|
||||||
layer.w2_weight_scale.shape[1]
|
layer.w2_weight_scale.shape[1]
|
||||||
* (self.group_size if self.group_size != -1 else self.packed_factor),
|
* (self.group_size if self.group_size != -1 else self.packed_factor),
|
||||||
size_k2,
|
layer.w2_weight_scale.shape[2],
|
||||||
self.group_size,
|
self.group_size,
|
||||||
self.num_bits,
|
|
||||||
)
|
)
|
||||||
replace_tensor("w2_weight_scale", marlin_w2_scales)
|
replace_parameter(layer, "w2_weight_scale", marlin_w2_scales)
|
||||||
|
|
||||||
def create_moe_runner(
|
def create_moe_runner(
|
||||||
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
|
||||||
@@ -716,7 +674,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
|||||||
|
|
||||||
topk_weights, topk_ids, router_logits = topk_output
|
topk_weights, topk_ids, router_logits = topk_output
|
||||||
|
|
||||||
output = torch.ops.vllm.fused_marlin_moe(
|
output = fused_marlin_moe(
|
||||||
x,
|
x,
|
||||||
layer.w13_weight_packed,
|
layer.w13_weight_packed,
|
||||||
layer.w2_weight_packed,
|
layer.w2_weight_packed,
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
|
|||||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||||
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
|
||||||
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
|
||||||
|
from .compressed_tensors_wNa16 import WNA16_SUPPORTED_BITS, CompressedTensorsWNA16
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CompressedTensorsScheme",
|
"CompressedTensorsScheme",
|
||||||
"CompressedTensorsW8A8Fp8",
|
"CompressedTensorsW8A8Fp8",
|
||||||
"CompressedTensorsW8A16Fp8",
|
"CompressedTensorsW8A16Fp8",
|
||||||
"CompressedTensorsW8A8Int8",
|
"CompressedTensorsW8A8Int8",
|
||||||
|
"CompressedTensorsWNA16",
|
||||||
|
"WNA16_SUPPORTED_BITS",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -14,24 +14,11 @@ from sglang.srt.layers.parameter import (
|
|||||||
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
CompressedTensorsScheme,
|
CompressedTensorsScheme,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.utils import convert_to_channelwise
|
from sglang.srt.layers.quantization.marlin_utils_fp8 import (
|
||||||
|
|
||||||
try:
|
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
|
||||||
apply_fp8_marlin_linear,
|
apply_fp8_marlin_linear,
|
||||||
prepare_fp8_layer_for_marlin,
|
prepare_fp8_layer_for_marlin,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.quantization.utils import convert_to_channelwise
|
||||||
MARLIN_FP8_AVAILABLE = True
|
|
||||||
except ImportError:
|
|
||||||
MARLIN_FP8_AVAILABLE = False
|
|
||||||
|
|
||||||
def apply_fp8_marlin_linear(*args, **kwargs):
|
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
def prepare_fp8_layer_for_marlin(*args, **kwargs):
|
|
||||||
raise ImportError("vllm is not installed")
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["CompressedTensorsW8A16Fp8"]
|
__all__ = ["CompressedTensorsW8A16Fp8"]
|
||||||
|
|
||||||
@@ -43,11 +30,6 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
|
|||||||
self.strategy = strategy
|
self.strategy = strategy
|
||||||
self.is_static_input_scheme = is_static_input_scheme
|
self.is_static_input_scheme = is_static_input_scheme
|
||||||
|
|
||||||
if not MARLIN_FP8_AVAILABLE:
|
|
||||||
raise ImportError(
|
|
||||||
"vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
# ampere and up
|
# ampere and up
|
||||||
|
|||||||
@@ -0,0 +1,339 @@
|
|||||||
|
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from compressed_tensors.quantization import ActivationOrdering
|
||||||
|
|
||||||
|
# yapf conflicts with isort for this block
|
||||||
|
# yapf: disable
|
||||||
|
from sglang.srt.layers.parameter import (
|
||||||
|
BasevLLMParameter,
|
||||||
|
ChannelQuantScaleParameter,
|
||||||
|
GroupQuantScaleParameter,
|
||||||
|
PackedColumnParameter,
|
||||||
|
PackedvLLMParameter,
|
||||||
|
RowvLLMParameter,
|
||||||
|
permute_param_layout_,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
|
||||||
|
CompressedTensorsScheme,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.marlin_utils import (
|
||||||
|
MarlinLinearLayerConfig,
|
||||||
|
apply_gptq_marlin_linear,
|
||||||
|
check_marlin_supports_shape,
|
||||||
|
marlin_is_k_full,
|
||||||
|
marlin_make_empty_g_idx,
|
||||||
|
marlin_make_workspace,
|
||||||
|
marlin_permute_scales,
|
||||||
|
marlin_repeat_scales_on_all_ranks,
|
||||||
|
marlin_sort_g_idx,
|
||||||
|
marlin_zero_points,
|
||||||
|
)
|
||||||
|
from sglang.srt.layers.quantization.utils import (
|
||||||
|
get_scalar_types,
|
||||||
|
replace_parameter,
|
||||||
|
unpack_cols,
|
||||||
|
)
|
||||||
|
from sglang.srt.utils import is_cuda
|
||||||
|
|
||||||
|
_is_cuda = is_cuda()
|
||||||
|
|
||||||
|
if _is_cuda:
|
||||||
|
from sgl_kernel import gptq_marlin_repack
|
||||||
|
|
||||||
|
|
||||||
|
ScalarType, scalar_types = get_scalar_types()
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
__all__ = ["CompressedTensorsWNA16"]
|
||||||
|
WNA16_SUPPORTED_TYPES_MAP = {
|
||||||
|
4: scalar_types.uint4b8,
|
||||||
|
8: scalar_types.uint8b128
|
||||||
|
}
|
||||||
|
WNA16_ZP_SUPPORTED_TYPES_MAP = {4: scalar_types.uint4, 8: scalar_types.uint8}
|
||||||
|
WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedTensorsWNA16(CompressedTensorsScheme):
|
||||||
|
_kernel_backends_being_used: set[str] = set()
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
strategy: str,
|
||||||
|
num_bits: int,
|
||||||
|
group_size: Optional[int] = None,
|
||||||
|
symmetric: Optional[bool] = True,
|
||||||
|
actorder: Optional[ActivationOrdering] = None):
|
||||||
|
|
||||||
|
self.pack_factor = 32 // num_bits
|
||||||
|
self.strategy = strategy
|
||||||
|
self.symmetric = symmetric
|
||||||
|
self.group_size = -1 if group_size is None else group_size
|
||||||
|
self.has_g_idx = actorder == ActivationOrdering.GROUP
|
||||||
|
|
||||||
|
if self.group_size == -1 and self.strategy != "channel":
|
||||||
|
raise ValueError("Marlin kernels require group quantization or "
|
||||||
|
"channelwise quantization, but found no group "
|
||||||
|
"size and strategy is not channelwise.")
|
||||||
|
|
||||||
|
if num_bits not in WNA16_SUPPORTED_TYPES_MAP:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported num_bits = {num_bits}. "
|
||||||
|
f"Supported num_bits = {WNA16_SUPPORTED_TYPES_MAP.keys()}")
|
||||||
|
|
||||||
|
self.quant_type = (WNA16_ZP_SUPPORTED_TYPES_MAP[num_bits]
|
||||||
|
if not self.symmetric else
|
||||||
|
WNA16_SUPPORTED_TYPES_MAP[num_bits])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_min_capability(cls) -> int:
|
||||||
|
# ampere and up
|
||||||
|
return 80
|
||||||
|
|
||||||
|
def create_weights(self, layer: torch.nn.Module, output_size: int,
|
||||||
|
input_size: int, output_partition_sizes: list[int],
|
||||||
|
input_size_per_partition: int,
|
||||||
|
params_dtype: torch.dtype, weight_loader: Callable,
|
||||||
|
**kwargs):
|
||||||
|
|
||||||
|
output_size_per_partition = sum(output_partition_sizes)
|
||||||
|
|
||||||
|
self.kernel_config = MarlinLinearLayerConfig(
|
||||||
|
full_weight_shape=(input_size, output_size),
|
||||||
|
partition_weight_shape=(
|
||||||
|
input_size_per_partition,
|
||||||
|
output_size_per_partition,
|
||||||
|
),
|
||||||
|
weight_type=self.quant_type,
|
||||||
|
act_type=params_dtype,
|
||||||
|
group_size=self.group_size,
|
||||||
|
zero_points=not self.symmetric,
|
||||||
|
has_g_idx=self.has_g_idx
|
||||||
|
)
|
||||||
|
|
||||||
|
# If group_size is -1, we are in channelwise case.
|
||||||
|
group_size = self.group_size if self.group_size != -1 else input_size
|
||||||
|
row_parallel = (input_size != input_size_per_partition)
|
||||||
|
partition_scales = not marlin_repeat_scales_on_all_ranks(
|
||||||
|
self.has_g_idx, self.group_size, row_parallel)
|
||||||
|
|
||||||
|
scales_and_zp_size = input_size // group_size
|
||||||
|
|
||||||
|
if partition_scales:
|
||||||
|
assert input_size_per_partition % group_size == 0
|
||||||
|
scales_and_zp_size = input_size_per_partition // group_size
|
||||||
|
|
||||||
|
weight = PackedvLLMParameter(input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
weight_loader=weight_loader,
|
||||||
|
packed_factor=self.pack_factor,
|
||||||
|
packed_dim=1,
|
||||||
|
data=torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
input_size_per_partition //
|
||||||
|
self.pack_factor,
|
||||||
|
dtype=torch.int32,
|
||||||
|
))
|
||||||
|
|
||||||
|
weight_scale_args = {
|
||||||
|
"weight_loader":
|
||||||
|
weight_loader,
|
||||||
|
"data":
|
||||||
|
torch.empty(
|
||||||
|
output_size_per_partition,
|
||||||
|
scales_and_zp_size,
|
||||||
|
dtype=params_dtype,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
zeros_args = {
|
||||||
|
"weight_loader":
|
||||||
|
weight_loader,
|
||||||
|
"data":
|
||||||
|
torch.zeros(
|
||||||
|
output_size_per_partition // self.pack_factor,
|
||||||
|
scales_and_zp_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if not partition_scales:
|
||||||
|
weight_scale = ChannelQuantScaleParameter(output_dim=0,
|
||||||
|
**weight_scale_args)
|
||||||
|
|
||||||
|
if not self.symmetric:
|
||||||
|
qzeros = PackedColumnParameter(output_dim=0,
|
||||||
|
packed_dim=0,
|
||||||
|
packed_factor=self.pack_factor,
|
||||||
|
**zeros_args)
|
||||||
|
else:
|
||||||
|
weight_scale = GroupQuantScaleParameter(output_dim=0,
|
||||||
|
input_dim=1,
|
||||||
|
**weight_scale_args)
|
||||||
|
if not self.symmetric:
|
||||||
|
qzeros = PackedvLLMParameter(input_dim=1,
|
||||||
|
output_dim=0,
|
||||||
|
packed_dim=0,
|
||||||
|
packed_factor=self.pack_factor,
|
||||||
|
**zeros_args)
|
||||||
|
|
||||||
|
# A 2D array defining the original shape of the weights
|
||||||
|
# before packing
|
||||||
|
weight_shape = BasevLLMParameter(data=torch.empty(2,
|
||||||
|
dtype=torch.int64),
|
||||||
|
weight_loader=weight_loader)
|
||||||
|
|
||||||
|
layer.register_parameter("weight_packed", weight)
|
||||||
|
layer.register_parameter("weight_scale", weight_scale)
|
||||||
|
layer.register_parameter("weight_shape", weight_shape)
|
||||||
|
|
||||||
|
if not self.symmetric:
|
||||||
|
layer.register_parameter("weight_zero_point", qzeros)
|
||||||
|
|
||||||
|
# group index (for activation reordering)
|
||||||
|
if self.has_g_idx:
|
||||||
|
weight_g_idx = RowvLLMParameter(data=torch.empty(
|
||||||
|
input_size_per_partition,
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
input_dim=0,
|
||||||
|
weight_loader=weight_loader)
|
||||||
|
layer.register_parameter("weight_g_idx", weight_g_idx)
|
||||||
|
|
||||||
|
# Checkpoints are serialized in compressed-tensors format, which is
|
||||||
|
# different from the format the kernel may want. Handle repacking here.
|
||||||
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
|
# Default names since marlin requires empty parameters for these,
|
||||||
|
# TODO: remove this requirement from marlin (allow optional tensors)
|
||||||
|
self.w_q_name = "weight_packed"
|
||||||
|
self.w_s_name = "weight_scale"
|
||||||
|
self.w_zp_name = "weight_zero_point"
|
||||||
|
self.w_gidx_name = "weight_g_idx"
|
||||||
|
|
||||||
|
device = getattr(layer, self.w_q_name).device
|
||||||
|
c = self.kernel_config
|
||||||
|
|
||||||
|
check_marlin_supports_shape(
|
||||||
|
c.partition_weight_shape[1], # out_features
|
||||||
|
c.partition_weight_shape[0], # in_features
|
||||||
|
c.full_weight_shape[0], # in_features
|
||||||
|
c.group_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
|
||||||
|
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
|
||||||
|
|
||||||
|
# Allocate marlin workspace.
|
||||||
|
self.workspace = marlin_make_workspace(device)
|
||||||
|
|
||||||
|
def _transform_param(
|
||||||
|
layer: torch.nn.Module, name: Optional[str], fn: Callable
|
||||||
|
) -> None:
|
||||||
|
if name is not None and getattr(layer, name, None) is not None:
|
||||||
|
|
||||||
|
old_param = getattr(layer, name)
|
||||||
|
new_param = fn(old_param)
|
||||||
|
# replace the parameter with torch.nn.Parameter for TorchDynamo
|
||||||
|
# compatibility
|
||||||
|
replace_parameter(
|
||||||
|
layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
def transform_w_q(x):
|
||||||
|
assert isinstance(x, BasevLLMParameter)
|
||||||
|
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||||
|
x.data = gptq_marlin_repack(
|
||||||
|
x.data.contiguous(),
|
||||||
|
perm=layer.g_idx_sort_indices,
|
||||||
|
size_k=c.partition_weight_shape[0],
|
||||||
|
size_n=c.partition_weight_shape[1],
|
||||||
|
num_bits=c.weight_type.size_bits,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
def transform_w_s(x):
|
||||||
|
assert isinstance(x, BasevLLMParameter)
|
||||||
|
permute_param_layout_(x, input_dim=0, output_dim=1)
|
||||||
|
x.data = marlin_permute_scales(
|
||||||
|
x.data.contiguous(),
|
||||||
|
size_k=c.partition_weight_shape[0],
|
||||||
|
size_n=c.partition_weight_shape[1],
|
||||||
|
group_size=c.group_size,
|
||||||
|
)
|
||||||
|
return x
|
||||||
|
|
||||||
|
if c.has_g_idx:
|
||||||
|
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||||
|
getattr(layer, self.w_gidx_name)
|
||||||
|
)
|
||||||
|
_transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||||
|
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||||
|
else:
|
||||||
|
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||||
|
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||||
|
|
||||||
|
if c.zero_points:
|
||||||
|
grouped_k = (
|
||||||
|
c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
|
||||||
|
)
|
||||||
|
_transform_param(
|
||||||
|
layer,
|
||||||
|
self.w_zp_name,
|
||||||
|
lambda x: marlin_zero_points(
|
||||||
|
unpack_cols(
|
||||||
|
x.t(),
|
||||||
|
c.weight_type.size_bits,
|
||||||
|
grouped_k,
|
||||||
|
c.partition_weight_shape[1],
|
||||||
|
),
|
||||||
|
size_k=grouped_k,
|
||||||
|
size_n=c.partition_weight_shape[1],
|
||||||
|
num_bits=c.weight_type.size_bits,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||||
|
_transform_param(layer, self.w_q_name, transform_w_q)
|
||||||
|
_transform_param(layer, self.w_s_name, transform_w_s)
|
||||||
|
|
||||||
|
def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor,
|
||||||
|
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||||
|
c = self.kernel_config
|
||||||
|
|
||||||
|
def _get_weight_params(
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
) -> tuple[
|
||||||
|
torch.Tensor, # w_q
|
||||||
|
torch.Tensor, # w_s
|
||||||
|
Optional[torch.Tensor], # w_zp,
|
||||||
|
Optional[torch.Tensor], # w_gidx
|
||||||
|
]:
|
||||||
|
return (
|
||||||
|
getattr(layer, self.w_q_name),
|
||||||
|
getattr(layer, self.w_s_name),
|
||||||
|
getattr(layer, self.w_zp_name or "", None),
|
||||||
|
getattr(layer, self.w_gidx_name or "", None),
|
||||||
|
)
|
||||||
|
|
||||||
|
w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
|
||||||
|
|
||||||
|
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
|
||||||
|
# None for marlin
|
||||||
|
return apply_gptq_marlin_linear(
|
||||||
|
input=x,
|
||||||
|
weight=w_q,
|
||||||
|
weight_scale=w_s,
|
||||||
|
weight_zp=w_zp, # type: ignore
|
||||||
|
g_idx=w_gidx, # type: ignore
|
||||||
|
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||||
|
workspace=self.workspace,
|
||||||
|
wtype=c.weight_type,
|
||||||
|
input_size_per_partition=c.partition_weight_shape[0],
|
||||||
|
output_size_per_partition=c.partition_weight_shape[1],
|
||||||
|
is_k_full=self.is_k_full,
|
||||||
|
bias=bias,
|
||||||
|
)
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Optional
|
from typing import TYPE_CHECKING, Any, Optional
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
@@ -57,6 +58,17 @@ MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
|
|||||||
USE_FP32_REDUCE_DEFAULT = True
|
USE_FP32_REDUCE_DEFAULT = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MarlinLinearLayerConfig:
|
||||||
|
full_weight_shape: tuple[int, int] # [in, out]
|
||||||
|
partition_weight_shape: tuple[int, int]
|
||||||
|
weight_type: ScalarType
|
||||||
|
act_type: torch.dtype
|
||||||
|
group_size: int
|
||||||
|
zero_points: bool
|
||||||
|
has_g_idx: bool
|
||||||
|
|
||||||
|
|
||||||
# For binary size and compile time, we don't support the same types for with and
|
# For binary size and compile time, we don't support the same types for with and
|
||||||
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
|
||||||
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
# TODO: we may want to move this into the C++ so its closer to the actual impl
|
||||||
|
|||||||
Reference in New Issue
Block a user