Replace sglang.srt.layers.quantization.scalar_types with sgl_kernel.scalar_type (#8951)
This commit is contained in:
@@ -29,9 +29,8 @@ from sglang.srt.layers.quantization.marlin_utils import (
|
|||||||
verify_marlin_supported,
|
verify_marlin_supported,
|
||||||
verify_marlin_supports_shape,
|
verify_marlin_supports_shape,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import scalar_types
|
|
||||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||||
from sglang.srt.layers.quantization.utils import replace_parameter
|
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.moe.topk import TopKOutput
|
from sglang.srt.layers.moe.topk import TopKOutput
|
||||||
@@ -52,6 +51,7 @@ _is_cuda = is_cuda()
|
|||||||
_is_hip = is_hip()
|
_is_hip = is_hip()
|
||||||
if _is_cuda:
|
if _is_cuda:
|
||||||
from sgl_kernel import awq_dequantize, fused_marlin_moe
|
from sgl_kernel import awq_dequantize, fused_marlin_moe
|
||||||
|
|
||||||
elif _is_hip:
|
elif _is_hip:
|
||||||
from sglang.srt.layers.quantization.awq_triton import (
|
from sglang.srt.layers.quantization.awq_triton import (
|
||||||
awq_dequantize_triton as awq_dequantize,
|
awq_dequantize_triton as awq_dequantize,
|
||||||
@@ -64,6 +64,9 @@ else:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
ScalarType, scalar_types = get_scalar_types()
|
||||||
|
|
||||||
|
|
||||||
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
|
||||||
return any(module_name in prefix for module_name in modules_to_not_convert)
|
return any(module_name in prefix for module_name in modules_to_not_convert)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
|
|||||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
all_close_1d,
|
all_close_1d,
|
||||||
cpu_has_amx_support,
|
|
||||||
per_tensor_dequantize,
|
per_tensor_dequantize,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
|
|||||||
marlin_zero_points,
|
marlin_zero_points,
|
||||||
verify_marlin_supported,
|
verify_marlin_supported,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
get_linear_quant_method,
|
get_linear_quant_method,
|
||||||
|
get_scalar_types,
|
||||||
replace_parameter,
|
replace_parameter,
|
||||||
unpack_cols,
|
unpack_cols,
|
||||||
)
|
)
|
||||||
@@ -60,6 +60,7 @@ if _is_cuda:
|
|||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
ScalarType, scalar_types = get_scalar_types()
|
||||||
|
|
||||||
|
|
||||||
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
|
def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
|
||||||
|
|||||||
@@ -19,8 +19,11 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
LinearMethodBase,
|
LinearMethodBase,
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
from sglang.srt.layers.quantization.utils import (
|
||||||
from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
|
get_scalar_types,
|
||||||
|
pack_cols,
|
||||||
|
unpack_cols,
|
||||||
|
)
|
||||||
from sglang.srt.utils import get_device_capability
|
from sglang.srt.utils import get_device_capability
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -33,6 +36,8 @@ except ImportError:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
ScalarType, scalar_types = get_scalar_types()
|
||||||
|
|
||||||
GPTQ_MARLIN_TILE = 16
|
GPTQ_MARLIN_TILE = 16
|
||||||
GPTQ_MARLIN_MIN_THREAD_N = 64
|
GPTQ_MARLIN_MIN_THREAD_N = 64
|
||||||
GPTQ_MARLIN_MIN_THREAD_K = 128
|
GPTQ_MARLIN_MIN_THREAD_K = 128
|
||||||
|
|||||||
@@ -1,352 +0,0 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
||||||
|
|
||||||
import functools
|
|
||||||
import struct
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
_SCALAR_TYPES_ID_MAP = {}
|
|
||||||
|
|
||||||
|
|
||||||
# Mirrors enum in `core/scalar_type.hpp`
|
|
||||||
class NanRepr(Enum):
|
|
||||||
NONE = 0 # nans are not supported
|
|
||||||
IEEE_754 = 1 # nans are: Exp all 1s, mantissa not all 0s
|
|
||||||
EXTD_RANGE_MAX_MIN = 2 # nans are: Exp all 1s, mantissa all 1s
|
|
||||||
|
|
||||||
|
|
||||||
# This ScalarType class is a parallel implementation of the C++ ScalarType
|
|
||||||
# class found in csrc/core/scalar_type.hpp. These two classes should be kept
|
|
||||||
# in sync until the inductor fully supports custom C++ classes.
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class ScalarType:
|
|
||||||
"""
|
|
||||||
ScalarType can represent a wide range of floating point and integer
|
|
||||||
types, in particular it can be used to represent sub-byte data types
|
|
||||||
(something that torch.dtype currently does not support). It is also
|
|
||||||
capable of representing types with a bias, i.e.:
|
|
||||||
`stored_value = value + bias`,
|
|
||||||
this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
|
|
||||||
of 8). The implementation for this class can be found in
|
|
||||||
csrc/core/scalar_type.hpp, these type signatures should be kept in sync
|
|
||||||
with that file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
exponent: int
|
|
||||||
"""
|
|
||||||
Number of bits in the exponent if this is a floating point type
|
|
||||||
(zero if this an integer type)
|
|
||||||
"""
|
|
||||||
|
|
||||||
mantissa: int
|
|
||||||
"""
|
|
||||||
Number of bits in the mantissa if this is a floating point type,
|
|
||||||
or the number bits representing an integer excluding the sign bit if
|
|
||||||
this an integer type.
|
|
||||||
"""
|
|
||||||
|
|
||||||
signed: bool
|
|
||||||
"If the type is signed (i.e. has a sign bit)"
|
|
||||||
|
|
||||||
bias: int
|
|
||||||
"""
|
|
||||||
bias used to encode the values in this scalar type
|
|
||||||
(value = stored_value - bias, default 0) for example if we store the
|
|
||||||
type as an unsigned integer with a bias of 128 then the value 0 will be
|
|
||||||
stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_finite_values_only: bool = False
|
|
||||||
"""
|
|
||||||
Private: if infs are supported, used `has_infs()` instead.
|
|
||||||
"""
|
|
||||||
|
|
||||||
nan_repr: NanRepr = NanRepr.IEEE_754
|
|
||||||
"""
|
|
||||||
How NaNs are represent in this scalar type, returns NanRepr value.
|
|
||||||
(not applicable for integer types)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _floating_point_max_int(self) -> int:
|
|
||||||
assert (
|
|
||||||
self.mantissa <= 52 and self.exponent <= 11
|
|
||||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
|
||||||
|
|
||||||
max_mantissa = (1 << self.mantissa) - 1
|
|
||||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
|
|
||||||
max_mantissa = max_mantissa - 1
|
|
||||||
|
|
||||||
max_exponent = (1 << self.exponent) - 2
|
|
||||||
if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
|
|
||||||
assert (
|
|
||||||
self.exponent < 11
|
|
||||||
), f"Cannot represent max/min as a double for type {self.__str__()}"
|
|
||||||
max_exponent = max_exponent + 1
|
|
||||||
|
|
||||||
# adjust the exponent to match that of a double
|
|
||||||
# for now we assume the exponent bias is the standard 2^(e-1) -1, (where
|
|
||||||
# e is the exponent bits), there is some precedent for non-standard
|
|
||||||
# biases, example `float8_e4m3b11fnuz` here:
|
|
||||||
# https://github.com/jax-ml/ml_dtypes but to avoid premature over
|
|
||||||
# complication we are just assuming the standard exponent bias until
|
|
||||||
# there is a need to support non-standard biases
|
|
||||||
exponent_bias = (1 << (self.exponent - 1)) - 1
|
|
||||||
exponent_bias_double = (1 << 10) - 1 # double e = 11
|
|
||||||
|
|
||||||
max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
|
|
||||||
|
|
||||||
# shift the mantissa and exponent into the proper positions for an
|
|
||||||
# IEEE double and bitwise-or them together.
|
|
||||||
return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
|
|
||||||
|
|
||||||
def _floating_point_max(self) -> float:
|
|
||||||
double_raw = self._floating_point_max_int()
|
|
||||||
return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
|
|
||||||
|
|
||||||
def _raw_max(self) -> Union[int, float]:
|
|
||||||
if self.is_floating_point():
|
|
||||||
return self._floating_point_max()
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
|
|
||||||
), "Cannot represent max as an int"
|
|
||||||
return (1 << self.mantissa) - 1
|
|
||||||
|
|
||||||
def _raw_min(self) -> Union[int, float]:
|
|
||||||
if self.is_floating_point():
|
|
||||||
assert (
|
|
||||||
self.is_signed()
|
|
||||||
), "We currently assume all floating point types are signed"
|
|
||||||
sign_bit_double = 1 << 63
|
|
||||||
|
|
||||||
max_raw = self._floating_point_max_int()
|
|
||||||
min_raw = max_raw | sign_bit_double
|
|
||||||
return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
|
|
||||||
else:
|
|
||||||
assert (
|
|
||||||
not self.is_signed() or self.size_bits <= 64
|
|
||||||
), "Cannot represent min as a int64_t"
|
|
||||||
|
|
||||||
if self.is_signed():
|
|
||||||
return -(1 << (self.size_bits - 1))
|
|
||||||
else:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
@functools.cached_property
|
|
||||||
def id(self) -> int:
|
|
||||||
"""
|
|
||||||
Convert the ScalarType to an int which can be passed to pytorch custom
|
|
||||||
ops. This layout of the int must be kept in sync with the C++
|
|
||||||
ScalarType's from_id method.
|
|
||||||
"""
|
|
||||||
val = 0
|
|
||||||
offset = 0
|
|
||||||
|
|
||||||
def or_and_advance(member, bit_width):
|
|
||||||
nonlocal val
|
|
||||||
nonlocal offset
|
|
||||||
bit_mask = (1 << bit_width) - 1
|
|
||||||
val = val | (int(member) & bit_mask) << offset
|
|
||||||
offset = offset + bit_width
|
|
||||||
|
|
||||||
or_and_advance(self.exponent, 8)
|
|
||||||
or_and_advance(self.mantissa, 8)
|
|
||||||
or_and_advance(self.signed, 1)
|
|
||||||
or_and_advance(self.bias, 32)
|
|
||||||
or_and_advance(self._finite_values_only, 1)
|
|
||||||
or_and_advance(self.nan_repr.value, 8)
|
|
||||||
|
|
||||||
assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
|
|
||||||
|
|
||||||
_SCALAR_TYPES_ID_MAP[val] = self
|
|
||||||
|
|
||||||
return val
|
|
||||||
|
|
||||||
@property
|
|
||||||
def size_bits(self) -> int:
|
|
||||||
return self.exponent + self.mantissa + int(self.signed)
|
|
||||||
|
|
||||||
def min(self) -> Union[int, float]:
|
|
||||||
"""
|
|
||||||
Min representable value for this scalar type.
|
|
||||||
(accounting for bias if there is one)
|
|
||||||
"""
|
|
||||||
return self._raw_min() - self.bias
|
|
||||||
|
|
||||||
def max(self) -> Union[int, float]:
|
|
||||||
"""
|
|
||||||
Max representable value for this scalar type.
|
|
||||||
(accounting for bias if there is one)
|
|
||||||
"""
|
|
||||||
return self._raw_max() - self.bias
|
|
||||||
|
|
||||||
def is_signed(self) -> bool:
|
|
||||||
"""
|
|
||||||
If the type is signed (i.e. has a sign bit), same as `signed`
|
|
||||||
added for consistency with:
|
|
||||||
https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
|
|
||||||
"""
|
|
||||||
return self.signed
|
|
||||||
|
|
||||||
def is_floating_point(self) -> bool:
|
|
||||||
"If the type is a floating point type"
|
|
||||||
return self.exponent != 0
|
|
||||||
|
|
||||||
def is_integer(self) -> bool:
|
|
||||||
"If the type is an integer type"
|
|
||||||
return self.exponent == 0
|
|
||||||
|
|
||||||
def has_bias(self) -> bool:
|
|
||||||
"If the type has a non-zero bias"
|
|
||||||
return self.bias != 0
|
|
||||||
|
|
||||||
def has_infs(self) -> bool:
|
|
||||||
"If the type is floating point and supports infinity"
|
|
||||||
return not self._finite_values_only
|
|
||||||
|
|
||||||
def has_nans(self) -> bool:
|
|
||||||
return self.nan_repr != NanRepr.NONE.value
|
|
||||||
|
|
||||||
def is_ieee_754(self) -> bool:
|
|
||||||
"""
|
|
||||||
If the type is a floating point type that follows IEEE 754
|
|
||||||
conventions
|
|
||||||
"""
|
|
||||||
return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
|
|
||||||
|
|
||||||
def __str__(self) -> str:
|
|
||||||
"""
|
|
||||||
naming generally follows: https://github.com/jax-ml/ml_dtypes
|
|
||||||
for floating point types (leading f) the scheme is:
|
|
||||||
`float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
|
||||||
flags:
|
|
||||||
- no-flags: means it follows IEEE 754 conventions
|
|
||||||
- f: means finite values only (no infinities)
|
|
||||||
- n: means nans are supported (non-standard encoding)
|
|
||||||
for integer types the scheme is:
|
|
||||||
`[u]int<size_bits>[b<bias>]`
|
|
||||||
- if bias is not present it means its zero
|
|
||||||
"""
|
|
||||||
if self.is_floating_point():
|
|
||||||
ret = (
|
|
||||||
"float"
|
|
||||||
+ str(self.size_bits)
|
|
||||||
+ "_e"
|
|
||||||
+ str(self.exponent)
|
|
||||||
+ "m"
|
|
||||||
+ str(self.mantissa)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not self.is_ieee_754():
|
|
||||||
if self._finite_values_only:
|
|
||||||
ret = ret + "f"
|
|
||||||
if self.nan_repr != NanRepr.NONE:
|
|
||||||
ret = ret + "n"
|
|
||||||
|
|
||||||
return ret
|
|
||||||
else:
|
|
||||||
ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
|
|
||||||
if self.has_bias():
|
|
||||||
ret = ret + "b" + str(self.bias)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return "ScalarType." + self.__str__()
|
|
||||||
|
|
||||||
# __len__ needs to be defined (and has to throw TypeError) for pytorch's
|
|
||||||
# opcheck to work.
|
|
||||||
def __len__(self) -> int:
|
|
||||||
raise TypeError
|
|
||||||
|
|
||||||
#
|
|
||||||
# Convenience Constructors
|
|
||||||
#
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
|
||||||
"Create a signed integer scalar type (size_bits includes sign-bit)."
|
|
||||||
ret = cls(0, size_bits - 1, True, bias if bias else 0)
|
|
||||||
ret.id # noqa B018: make sure the id is cached
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
|
|
||||||
"""Create a unsigned integer scalar type."""
|
|
||||||
ret = cls(0, size_bits, False, bias if bias else 0)
|
|
||||||
ret.id # noqa B018: make sure the id is cached
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
|
|
||||||
"""
|
|
||||||
Create a standard floating point type
|
|
||||||
(i.e. follows IEEE 754 conventions).
|
|
||||||
"""
|
|
||||||
assert mantissa > 0 and exponent > 0
|
|
||||||
ret = cls(exponent, mantissa, True, 0)
|
|
||||||
ret.id # noqa B018: make sure the id is cached
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def float_(
|
|
||||||
cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
|
|
||||||
) -> "ScalarType":
|
|
||||||
"""
|
|
||||||
Create a non-standard floating point type
|
|
||||||
(i.e. does not follow IEEE 754 conventions).
|
|
||||||
"""
|
|
||||||
assert mantissa > 0 and exponent > 0
|
|
||||||
assert nan_repr != NanRepr.IEEE_754, (
|
|
||||||
"use `float_IEEE754` constructor for floating point types that "
|
|
||||||
"follow IEEE 754 conventions"
|
|
||||||
)
|
|
||||||
ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
|
|
||||||
ret.id # noqa B018: make sure the id is cached
|
|
||||||
return ret
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_id(cls, scalar_type_id: int):
|
|
||||||
if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
|
|
||||||
raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
|
|
||||||
return _SCALAR_TYPES_ID_MAP[scalar_type_id]
|
|
||||||
|
|
||||||
|
|
||||||
# naming generally follows: https://github.com/jax-ml/ml_dtypes
|
|
||||||
# for floating point types (leading f) the scheme is:
|
|
||||||
# `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
|
||||||
# flags:
|
|
||||||
# - no-flags: means it follows IEEE 754 conventions
|
|
||||||
# - f: means finite values only (no infinities)
|
|
||||||
# - n: means nans are supported (non-standard encoding)
|
|
||||||
# for integer types the scheme is:
|
|
||||||
# `[u]int<size_bits>[b<bias>]`
|
|
||||||
# - if bias is not present it means its zero
|
|
||||||
|
|
||||||
|
|
||||||
class scalar_types:
|
|
||||||
int4 = ScalarType.int_(4, None)
|
|
||||||
uint4 = ScalarType.uint(4, None)
|
|
||||||
int8 = ScalarType.int_(8, None)
|
|
||||||
uint8 = ScalarType.uint(8, None)
|
|
||||||
float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
|
|
||||||
float8_e5m2 = ScalarType.float_IEEE754(5, 2)
|
|
||||||
float16_e8m7 = ScalarType.float_IEEE754(8, 7)
|
|
||||||
float16_e5m10 = ScalarType.float_IEEE754(5, 10)
|
|
||||||
|
|
||||||
# fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
|
|
||||||
float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
|
|
||||||
|
|
||||||
# fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
|
|
||||||
float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
|
|
||||||
|
|
||||||
# "gptq" types
|
|
||||||
uint2b2 = ScalarType.uint(2, 2)
|
|
||||||
uint3b4 = ScalarType.uint(3, 4)
|
|
||||||
uint4b8 = ScalarType.uint(4, 8)
|
|
||||||
uint8b128 = ScalarType.uint(8, 128)
|
|
||||||
|
|
||||||
# colloquial names
|
|
||||||
bfloat16 = float16_e8m7
|
|
||||||
float16 = float16_e5m10
|
|
||||||
@@ -11,13 +11,39 @@ import numpy
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
from sglang.srt.utils import is_cuda
|
||||||
from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
|
|
||||||
|
|
||||||
|
def get_scalar_types():
|
||||||
|
"""
|
||||||
|
Returns:
|
||||||
|
tuple: (ScalarType, scalar_types)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from sgl_kernel.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
|
return ScalarType, scalar_types
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
class MockScalarType:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class MockScalarTypes:
|
||||||
|
uint4b8 = "uint4b8"
|
||||||
|
uint8b128 = "uint8b128"
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
return f"mock_{name}"
|
||||||
|
|
||||||
|
return MockScalarType, MockScalarTypes()
|
||||||
|
|
||||||
|
|
||||||
|
ScalarType, scalar_types = get_scalar_types()
|
||||||
|
|
||||||
|
|
||||||
def is_layer_skipped(
|
def is_layer_skipped(
|
||||||
prefix: str,
|
prefix: str,
|
||||||
ignored_layers: List[str],
|
ignored_layers: List[str],
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ from typing import Optional
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
from sgl_kernel import fused_marlin_moe
|
from sgl_kernel import fused_marlin_moe
|
||||||
|
from sgl_kernel.scalar_type import ScalarType, scalar_types
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
|
|
||||||
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
|
from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -10,13 +10,13 @@ from typing import Optional
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
from sgl_kernel.scalar_type import ScalarType
|
||||||
|
|
||||||
from sglang.srt.layers.quantization.marlin_utils import (
|
from sglang.srt.layers.quantization.marlin_utils import (
|
||||||
GPTQ_MARLIN_TILE,
|
GPTQ_MARLIN_TILE,
|
||||||
marlin_permute_scales,
|
marlin_permute_scales,
|
||||||
marlin_zero_points,
|
marlin_zero_points,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.quantization.scalar_type import ScalarType
|
|
||||||
from sglang.srt.layers.quantization.utils import (
|
from sglang.srt.layers.quantization.utils import (
|
||||||
get_pack_factor,
|
get_pack_factor,
|
||||||
gptq_quantize_weights,
|
gptq_quantize_weights,
|
||||||
|
|||||||
Reference in New Issue
Block a user