From a669bc2f74eee618c28c6d3db0ddf74db9ac2d92 Mon Sep 17 00:00:00 2001
From: Hongbo Xu <1320612015@qq.com>
Date: Thu, 14 Aug 2025 10:41:41 +0800
Subject: [PATCH] Replace `sglang.srt.layers.quantization.scalar_types` with
 `sgl_kernel.scalar_type` (#8951)

---
 python/sglang/srt/layers/quantization/awq.py  |   7 +-
 .../compressed_tensors_moe.py                 |   1 -
 python/sglang/srt/layers/quantization/gptq.py |   3 +-
 .../srt/layers/quantization/marlin_utils.py   |   9 +-
 .../srt/layers/quantization/scalar_type.py    | 352 ------------------
 .../sglang/srt/layers/quantization/utils.py   |  30 +-
 python/sglang/test/test_marlin_moe.py         |   2 +-
 python/sglang/test/test_marlin_utils.py       |   2 +-
 8 files changed, 44 insertions(+), 362 deletions(-)
 delete mode 100644 python/sglang/srt/layers/quantization/scalar_type.py

diff --git a/python/sglang/srt/layers/quantization/awq.py b/python/sglang/srt/layers/quantization/awq.py
index 0f66b954c..f5111df74 100644
--- a/python/sglang/srt/layers/quantization/awq.py
+++ b/python/sglang/srt/layers/quantization/awq.py
@@ -29,9 +29,8 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -52,6 +51,7 @@ _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel import awq_dequantize, fused_marlin_moe
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +64,9 @@ else:
 
 logger = logging.getLogger(__name__)
 
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
 
diff --git a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c6da7e149..c2e908f8c 100644
--- a/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -16,7 +16,6 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_qu
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
-    cpu_has_amx_support,
     per_tensor_dequantize,
     replace_parameter,
 )
diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py
index 4f2eba4e3..fd510dd24 100644
--- a/python/sglang/srt/layers/quantization/gptq.py
+++ b/python/sglang/srt/layers/quantization/gptq.py
@@ -36,9 +36,9 @@ from sglang.srt.layers.quantization.marlin_utils import (
     marlin_zero_points,
     verify_marlin_supported,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.srt.layers.quantization.utils import (
     get_linear_quant_method,
+    get_scalar_types,
     replace_parameter,
     unpack_cols,
 )
@@ -60,6 +60,7 @@ if _is_cuda:
 
 logger = logging.getLogger(__name__)
 
+ScalarType, scalar_types = get_scalar_types()
 
 
 def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
diff --git a/python/sglang/srt/layers/quantization/marlin_utils.py b/python/sglang/srt/layers/quantization/marlin_utils.py
index 1edc672ab..1873d3970 100644
--- a/python/sglang/srt/layers/quantization/marlin_utils.py
+++ b/python/sglang/srt/layers/quantization/marlin_utils.py
@@ -19,8 +19,11 @@ from sglang.srt.layers.quantization.base_config import (
     LinearMethodBase,
     QuantizationConfig,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
+from sglang.srt.layers.quantization.utils import (
+    get_scalar_types,
+    pack_cols,
+    unpack_cols,
+)
 from sglang.srt.utils import get_device_capability
 
 if TYPE_CHECKING:
@@ -33,6 +36,8 @@ except ImportError:
 
 logger = logging.getLogger(__name__)
 
+ScalarType, scalar_types = get_scalar_types()
+
 GPTQ_MARLIN_TILE = 16
 GPTQ_MARLIN_MIN_THREAD_N = 64
 GPTQ_MARLIN_MIN_THREAD_K = 128
diff --git a/python/sglang/srt/layers/quantization/scalar_type.py b/python/sglang/srt/layers/quantization/scalar_type.py
deleted file mode 100644
index 5aeb88651..000000000
--- a/python/sglang/srt/layers/quantization/scalar_type.py
+++ /dev/null
@@ -1,352 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import functools
-import struct
-from dataclasses import dataclass
-from enum import Enum
-from typing import Optional, Union
-
-_SCALAR_TYPES_ID_MAP = {}
-
-
-# Mirrors enum in `core/scalar_type.hpp`
-class NanRepr(Enum):
-    NONE = 0  # nans are not supported
-    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
-    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
-
-
-# This ScalarType class is a parallel implementation of the C++ ScalarType
-# class found in csrc/core/scalar_type.hpp. These two classes should be kept
-# in sync until the inductor fully supports custom C++ classes.
-@dataclass(frozen=True)
-class ScalarType:
-    """
-    ScalarType can represent a wide range of floating point and integer
-    types, in particular it can be used to represent sub-byte data types
-    (something that torch.dtype currently does not support). It is also
-    capable of representing types with a bias, i.e.:
-      `stored_value = value + bias`,
-    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
-    of 8). The implementation for this class can be found in
-    csrc/core/scalar_type.hpp, these type signatures should be kept in sync
-    with that file.
-    """
-
-    exponent: int
-    """
-    Number of bits in the exponent if this is a floating point type
-    (zero if this an integer type)
-    """
-
-    mantissa: int
-    """
-    Number of bits in the mantissa if this is a floating point type,
-    or the number bits representing an integer excluding the sign bit if
-    this an integer type.
-    """
-
-    signed: bool
-    "If the type is signed (i.e. has a sign bit)"
-
-    bias: int
-    """
-    bias used to encode the values in this scalar type
-    (value = stored_value - bias, default 0) for example if we store the
-    type as an unsigned integer with a bias of 128 then the value 0 will be
-    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
-    """
-
-    _finite_values_only: bool = False
-    """
-    Private: if infs are supported, used `has_infs()` instead.
-    """
-
-    nan_repr: NanRepr = NanRepr.IEEE_754
-    """
-    How NaNs are represent in this scalar type, returns NanRepr value.
-    (not applicable for integer types)
-    """
-
-    def _floating_point_max_int(self) -> int:
-        assert (
-            self.mantissa <= 52 and self.exponent <= 11
-        ), f"Cannot represent max/min as a double for type {self.__str__()}"
-
-        max_mantissa = (1 << self.mantissa) - 1
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
-            max_mantissa = max_mantissa - 1
-
-        max_exponent = (1 << self.exponent) - 2
-        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
-            assert (
-                self.exponent < 11
-            ), f"Cannot represent max/min as a double for type {self.__str__()}"
-            max_exponent = max_exponent + 1
-
-        # adjust the exponent to match that of a double
-        # for now we assume the exponent bias is the standard 2^(e-1) -1, (where
-        # e is the exponent bits), there is some precedent for non-standard
-        # biases, example `float8_e4m3b11fnuz` here:
-        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
-        # complication we are just assuming the standard exponent bias until
-        # there is a need to support non-standard biases
-        exponent_bias = (1 << (self.exponent - 1)) - 1
-        exponent_bias_double = (1 << 10) - 1  # double e = 11
-
-        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
-
-        # shift the mantissa and exponent into the proper positions for an
-        # IEEE double and bitwise-or them together.
-        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
-
-    def _floating_point_max(self) -> float:
-        double_raw = self._floating_point_max_int()
-        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
-
-    def _raw_max(self) -> Union[int, float]:
-        if self.is_floating_point():
-            return self._floating_point_max()
-        else:
-            assert (
-                self.size_bits < 64 or self.size_bits == 64 and self.is_signed()
-            ), "Cannot represent max as an int"
-            return (1 << self.mantissa) - 1
-
-    def _raw_min(self) -> Union[int, float]:
-        if self.is_floating_point():
-            assert (
-                self.is_signed()
-            ), "We currently assume all floating point types are signed"
-            sign_bit_double = 1 << 63
-
-            max_raw = self._floating_point_max_int()
-            min_raw = max_raw | sign_bit_double
-            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
-        else:
-            assert (
-                not self.is_signed() or self.size_bits <= 64
-            ), "Cannot represent min as a int64_t"
-
-            if self.is_signed():
-                return -(1 << (self.size_bits - 1))
-            else:
-                return 0
-
-    @functools.cached_property
-    def id(self) -> int:
-        """
-        Convert the ScalarType to an int which can be passed to pytorch custom
-        ops. This layout of the int must be kept in sync with the C++
-        ScalarType's from_id method.
-        """
-        val = 0
-        offset = 0
-
-        def or_and_advance(member, bit_width):
-            nonlocal val
-            nonlocal offset
-            bit_mask = (1 << bit_width) - 1
-            val = val | (int(member) & bit_mask) << offset
-            offset = offset + bit_width
-
-        or_and_advance(self.exponent, 8)
-        or_and_advance(self.mantissa, 8)
-        or_and_advance(self.signed, 1)
-        or_and_advance(self.bias, 32)
-        or_and_advance(self._finite_values_only, 1)
-        or_and_advance(self.nan_repr.value, 8)
-
-        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
-
-        _SCALAR_TYPES_ID_MAP[val] = self
-
-        return val
-
-    @property
-    def size_bits(self) -> int:
-        return self.exponent + self.mantissa + int(self.signed)
-
-    def min(self) -> Union[int, float]:
-        """
-        Min representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_min() - self.bias
-
-    def max(self) -> Union[int, float]:
-        """
-        Max representable value for this scalar type.
-        (accounting for bias if there is one)
-        """
-        return self._raw_max() - self.bias
-
-    def is_signed(self) -> bool:
-        """
-        If the type is signed (i.e. has a sign bit), same as `signed`
-        added for consistency with:
-        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
-        """
-        return self.signed
-
-    def is_floating_point(self) -> bool:
-        "If the type is a floating point type"
-        return self.exponent != 0
-
-    def is_integer(self) -> bool:
-        "If the type is an integer type"
-        return self.exponent == 0
-
-    def has_bias(self) -> bool:
-        "If the type has a non-zero bias"
-        return self.bias != 0
-
-    def has_infs(self) -> bool:
-        "If the type is floating point and supports infinity"
-        return not self._finite_values_only
-
-    def has_nans(self) -> bool:
-        return self.nan_repr != NanRepr.NONE.value
-
-    def is_ieee_754(self) -> bool:
-        """
-        If the type is a floating point type that follows IEEE 754
-        conventions
-        """
-        return self.nan_repr == NanRepr.IEEE_754.value and not self._finite_values_only
-
-    def __str__(self) -> str:
-        """
-        naming generally follows: https://github.com/jax-ml/ml_dtypes
-        for floating point types (leading f) the scheme is:
-        `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-        flags:
-          - no-flags: means it follows IEEE 754 conventions
-          - f: means finite values only (no infinities)
-          - n: means nans are supported (non-standard encoding)
-        for integer types the scheme is:
-        `[u]int<size_bits>[b<bias>]`
-          - if bias is not present it means its zero
-        """
-        if self.is_floating_point():
-            ret = (
-                "float"
-                + str(self.size_bits)
-                + "_e"
-                + str(self.exponent)
-                + "m"
-                + str(self.mantissa)
-            )
-
-            if not self.is_ieee_754():
-                if self._finite_values_only:
-                    ret = ret + "f"
-                if self.nan_repr != NanRepr.NONE:
-                    ret = ret + "n"
-
-            return ret
-        else:
-            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
-            if self.has_bias():
-                ret = ret + "b" + str(self.bias)
-            return ret
-
-    def __repr__(self) -> str:
-        return "ScalarType." + self.__str__()
-
-    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
-    # opcheck to work.
-    def __len__(self) -> int:
-        raise TypeError
-
-    #
-    # Convenience Constructors
-    #
-
-    @classmethod
-    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        "Create a signed integer scalar type (size_bits includes sign-bit)."
-        ret = cls(0, size_bits - 1, True, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
-        """Create a unsigned integer scalar type."""
-        ret = cls(0, size_bits, False, bias if bias else 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
-        """
-        Create a standard floating point type
-        (i.e. follows IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        ret = cls(exponent, mantissa, True, 0)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def float_(
-        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
-    ) -> "ScalarType":
-        """
-        Create a non-standard floating point type
-        (i.e. does not follow IEEE 754 conventions).
-        """
-        assert mantissa > 0 and exponent > 0
-        assert nan_repr != NanRepr.IEEE_754, (
-            "use `float_IEEE754` constructor for floating point types that "
-            "follow IEEE 754 conventions"
-        )
-        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
-        ret.id  # noqa B018: make sure the id is cached
-        return ret
-
-    @classmethod
-    def from_id(cls, scalar_type_id: int):
-        if scalar_type_id not in _SCALAR_TYPES_ID_MAP:
-            raise ValueError(f"scalar_type_id {scalar_type_id} doesn't exists.")
-        return _SCALAR_TYPES_ID_MAP[scalar_type_id]
-
-
-# naming generally follows: https://github.com/jax-ml/ml_dtypes
-# for floating point types (leading f) the scheme is:
-#  `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
-#  flags:
-#  - no-flags: means it follows IEEE 754 conventions
-#  - f: means finite values only (no infinities)
-#  - n: means nans are supported (non-standard encoding)
-# for integer types the scheme is:
-#  `[u]int<size_bits>[b<bias>]`
-#  - if bias is not present it means its zero
-
-
-class scalar_types:
-    int4 = ScalarType.int_(4, None)
-    uint4 = ScalarType.uint(4, None)
-    int8 = ScalarType.int_(8, None)
-    uint8 = ScalarType.uint(8, None)
-    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
-    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
-    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
-    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
-
-    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
-    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
-
-    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
-    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)
-
-    # "gptq" types
-    uint2b2 = ScalarType.uint(2, 2)
-    uint3b4 = ScalarType.uint(3, 4)
-    uint4b8 = ScalarType.uint(4, 8)
-    uint8b128 = ScalarType.uint(8, 128)
-
-    # colloquial names
-    bfloat16 = float16_e8m7
-    float16 = float16_e5m10
diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py
index 9b19e0309..85d3d8933 100644
--- a/python/sglang/srt/layers/quantization/utils.py
+++ b/python/sglang/srt/layers/quantization/utils.py
@@ -11,13 +11,39 @@ import numpy
 import torch
 
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import is_cuda
 
 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig
 
 
+def get_scalar_types():
+    """
+    Returns:
+        tuple: (ScalarType, scalar_types)
+    """
+    try:
+        from sgl_kernel.scalar_type import ScalarType, scalar_types
+
+        return ScalarType, scalar_types
+    except ImportError:
+
+        class MockScalarType:
+            pass
+
+        class MockScalarTypes:
+            uint4b8 = "uint4b8"
+            uint8b128 = "uint8b128"
+
+            def __getattr__(self, name):
+                return f"mock_{name}"
+
+        return MockScalarType, MockScalarTypes()
+
+
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
diff --git a/python/sglang/test/test_marlin_moe.py b/python/sglang/test/test_marlin_moe.py
index e5b4c986a..77b0109df 100644
--- a/python/sglang/test/test_marlin_moe.py
+++ b/python/sglang/test/test_marlin_moe.py
@@ -4,9 +4,9 @@ from typing import Optional
 import pytest
 import torch
 from sgl_kernel import fused_marlin_moe
+from sgl_kernel.scalar_type import ScalarType, scalar_types
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
 from sglang.test.test_marlin_utils import awq_marlin_quantize, marlin_quantize
 
 
diff --git a/python/sglang/test/test_marlin_utils.py b/python/sglang/test/test_marlin_utils.py
index 920cb7d8b..0c0590077 100644
--- a/python/sglang/test/test_marlin_utils.py
+++ b/python/sglang/test/test_marlin_utils.py
@@ -10,13 +10,13 @@ from typing import Optional
 
 import numpy as np
 import torch
+from sgl_kernel.scalar_type import ScalarType
 
 from sglang.srt.layers.quantization.marlin_utils import (
     GPTQ_MARLIN_TILE,
     marlin_permute_scales,
     marlin_zero_points,
 )
-from sglang.srt.layers.quantization.scalar_type import ScalarType
 from sglang.srt.layers.quantization.utils import (
     get_pack_factor,
     gptq_quantize_weights,
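
A minimal usage sketch, not part of the patch itself: it assumes the patched sglang tree, where `sgl_kernel` may or may not be importable at runtime. Call sites now obtain the scalar types through the `get_scalar_types()` helper added to `python/sglang/srt/layers/quantization/utils.py`, which prefers the `sgl_kernel.scalar_type` implementation and falls back to the mock classes when the kernel package is absent:

    # Sketch: exercises only names introduced by this patch.
    from sglang.srt.layers.quantization.utils import get_scalar_types

    ScalarType, scalar_types = get_scalar_types()

    # With sgl_kernel installed, uint4b8 is a real ScalarType (4-bit unsigned
    # integer stored with a bias of 8, the standard GPTQ 4-bit encoding).
    # Without sgl_kernel, it is the plain string "uint4b8" from MockScalarTypes.
    quant_type = scalar_types.uint4b8
    print(quant_type)

The mock fallback trades fidelity for importability: any code path that actually runs quantized kernels still needs the real `sgl_kernel` types, but environments without the compiled kernel package can at least import these modules at load time.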