From 9b81f9bd349cb483cbf16196f749762462af0eac Mon Sep 17 00:00:00 2001 From: Xiaoyu Zhang <35585791+BBuf@users.noreply.github.com> Date: Tue, 18 Mar 2025 06:51:59 +0800 Subject: [PATCH] sglang quant module remove vllm dependency (#4507) --- .../srt/layers/quantization/__init__.py | 297 +++++++----- .../srt/layers/quantization/blockwise_int8.py | 2 +- python/sglang/srt/layers/quantization/fp8.py | 91 ++-- .../srt/layers/quantization/fp8_utils.py | 172 +++---- python/sglang/srt/layers/quantization/gptq.py | 27 +- .../srt/layers/quantization/kv_cache.py | 98 ++++ .../srt/layers/quantization/modelopt_quant.py | 16 +- .../sglang/srt/layers/quantization/utils.py | 442 ++++++++++++++++++ 8 files changed, 907 insertions(+), 238 deletions(-) create mode 100644 python/sglang/srt/layers/quantization/kv_cache.py create mode 100644 python/sglang/srt/layers/quantization/utils.py diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index bc1bedee1..8de731420 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -6,21 +6,41 @@ from copy import deepcopy from typing import Callable, Dict, Optional, Type, Union import torch -from vllm.model_executor.layers.quantization.aqlm import AQLMConfig -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig -from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( - CompressedTensorsConfig, -) -from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig -from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config -from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config -from vllm.model_executor.layers.quantization.gguf import GGUFConfig -from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config -from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.qqq import QQQConfig -from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig + +try: + from vllm.model_executor.layers.quantization.aqlm import AQLMConfig + from vllm.model_executor.layers.quantization.awq import AWQConfig + from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig + from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( + CompressedTensorsConfig, + ) + from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig + from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config + from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config + from vllm.model_executor.layers.quantization.gguf import GGUFConfig + from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( + GPTQMarlin24Config, + ) + from vllm.model_executor.layers.quantization.marlin import MarlinConfig + from vllm.model_executor.layers.quantization.qqq import QQQConfig + from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + + # Define empty classes as placeholders when vllm is not available + class DummyConfig: + pass + + AQLMConfig 
= AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = ( + CompressedTensorsConfig + ) = DummyConfig + DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = ( + GPTQMarlin24Config + ) = DummyConfig + MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config @@ -30,29 +50,37 @@ from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config -QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { - "aqlm": AQLMConfig, - "awq": AWQConfig, - "deepspeedfp": DeepSpeedFPConfig, - "tpu_int8": Int8TpuConfig, +# Base quantization methods that don't depend on vllm +BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "fp8": Fp8Config, "blockwise_int8": BlockInt8Config, - "fbgemm_fp8": FBGEMMFp8Config, - "marlin": MarlinConfig, "modelopt": ModelOptFp8Config, - "gguf": GGUFConfig, - "gptq_marlin_24": GPTQMarlin24Config, "gptq_marlin": GPTQMarlinConfig, - "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, - "compressed-tensors": CompressedTensorsConfig, - "bitsandbytes": BitsAndBytesConfig, - "qqq": QQQConfig, - "experts_int8": ExpertsInt8Config, "w8a8_int8": W8A8Int8Config, "w8a8_fp8": W8A8Fp8Config, } +# Add vllm-dependent methods if available +QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy() +if VLLM_AVAILABLE: + VLLM_QUANTIZATION_METHODS = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fbgemm_fp8": FBGEMMFp8Config, + "marlin": MarlinConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "awq_marlin": AWQMarlinConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "qqq": QQQConfig, + "experts_int8": ExpertsInt8Config, + } + QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS) + def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in QUANTIZATION_METHODS: @@ -157,25 +185,31 @@ def get_linear_quant_method( def gptq_get_quant_method(self, layer, prefix): - from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod - from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinLinearMethod, - GPTQMarlinMoEMethod, - ) + if not VLLM_AVAILABLE: + return None - from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE - - if isinstance(layer, FusedMoE): - return GPTQMarlinMoEMethod(self) - - if isinstance(self, GPTQConfig): - return get_linear_quant_method( - self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod - ) - elif isinstance(self, GPTQMarlinConfig): - return get_linear_quant_method( - self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod + try: + from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinLinearMethod, + GPTQMarlinMoEMethod, ) + + from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE + + if isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) + + if isinstance(self, GPTQConfig): + return get_linear_quant_method( + self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod + ) + elif isinstance(self, GPTQMarlinConfig): + return get_linear_quant_method( + self, layer, prefix=prefix, 
linear_method_cls=GPTQMarlinLinearMethod + ) + except ImportError: + pass return None @@ -187,33 +221,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False): Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig can recognize sglang layers """ + if not VLLM_AVAILABLE: + return if reverse: builtins.isinstance = original_isinstance return - from vllm.model_executor.layers.fused_moe import FusedMoE - from vllm.model_executor.layers.linear import LinearBase - from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, - ) + try: + from vllm.model_executor.layers.fused_moe import FusedMoE + from vllm.model_executor.layers.linear import LinearBase + from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ) - from sglang.srt.layers.linear import LinearBase as PatchedLinearBase - from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE - from sglang.srt.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding as PatchedVocabParallelEmbedding, - ) + from sglang.srt.layers.linear import LinearBase as PatchedLinearBase + from sglang.srt.layers.moe.fused_moe_triton.layer import ( + FusedMoE as PatchedFusedMoE, + ) + from sglang.srt.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding as PatchedVocabParallelEmbedding, + ) - def patched_isinstance(obj, classinfo): - if classinfo is LinearBase: - return original_isinstance(obj, PatchedLinearBase) - if classinfo is FusedMoE: - return original_isinstance(obj, PatchedFusedMoE) - if classinfo is VocabParallelEmbedding: - return original_isinstance(obj, PatchedVocabParallelEmbedding) - return original_isinstance(obj, classinfo) + def patched_isinstance(obj, classinfo): + if classinfo is LinearBase: + return original_isinstance(obj, PatchedLinearBase) + if classinfo is FusedMoE: + return original_isinstance(obj, PatchedFusedMoE) + if classinfo is VocabParallelEmbedding: + return original_isinstance(obj, PatchedVocabParallelEmbedding) + return original_isinstance(obj, classinfo) - builtins.isinstance = patched_isinstance + builtins.isinstance = patched_isinstance + except ImportError: + return def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): @@ -221,72 +262,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"): Monkey patch the apply function of vllm's FusedMoEMethodBase. Convert sglang arguments to vllm arguments. 
""" - original_apply = class_obj.apply - sig = inspect.signature(original_apply) - param_names = list(sig.parameters.keys()) - has_correction_bias = "e_score_correction_bias" in param_names + if not VLLM_AVAILABLE: + return - def new_apply( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - correction_bias: Optional[torch.Tensor] = None, - activation: str = "silu", - inplace: bool = True, - no_combine: bool = False, - ): - assert activation == "silu" - assert inplace and not no_combine + try: + original_apply = class_obj.apply + sig = inspect.signature(original_apply) + param_names = list(sig.parameters.keys()) + has_correction_bias = "e_score_correction_bias" in param_names - kwargs = { - "self": self, - "layer": layer, - "x": x, - "router_logits": router_logits, - "top_k": top_k, - "renormalize": renormalize, - "use_grouped_topk": use_grouped_topk, - "topk_group": topk_group, - "num_expert_group": num_expert_group, - "custom_routing_function": custom_routing_function, - } - if correction_bias is not None: - if not has_correction_bias: - raise ValueError( - "Please increase the version of your vllm. Try `pip install vllm==0.7.2`" - ) - kwargs["e_score_correction_bias"] = correction_bias - return original_apply(**kwargs) + def new_apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + inplace: bool = True, + no_combine: bool = False, + ): + assert activation == "silu" + assert inplace and not no_combine - setattr(class_obj, "apply", new_apply) + kwargs = { + "self": self, + "layer": layer, + "x": x, + "router_logits": router_logits, + "top_k": top_k, + "renormalize": renormalize, + "use_grouped_topk": use_grouped_topk, + "topk_group": topk_group, + "num_expert_group": num_expert_group, + "custom_routing_function": custom_routing_function, + } + if correction_bias is not None: + if not has_correction_bias: + raise ValueError( + "Please increase the version of your vllm. 
Try `pip install vllm==0.7.2`" + ) + kwargs["e_score_correction_bias"] = correction_bias + return original_apply(**kwargs) + + setattr(class_obj, "apply", new_apply) + except (ImportError, AttributeError): + return def monkey_patch_quant_configs(): """Apply all monkey patches in one place.""" - from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod - from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( - CompressedTensorsW8A8Fp8MoEMethod, - CompressedTensorsWNA16MoEMethod, - ) - from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod + if not VLLM_AVAILABLE: + return - setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) - setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method) + try: + from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod + from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( + CompressedTensorsW8A8Fp8MoEMethod, + CompressedTensorsWNA16MoEMethod, + ) + from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinMoEMethod, + ) - monkey_patch_moe_apply(AWQMoEMethod) - monkey_patch_moe_apply(GPTQMarlinMoEMethod) - monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod) - monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod) + setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method) + setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method) + + monkey_patch_moe_apply(AWQMoEMethod) + monkey_patch_moe_apply(GPTQMarlinMoEMethod) + monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod) + monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod) + except ImportError: + return -monkey_patch_quant_configs() +# Only apply monkey patches if vllm is available +if VLLM_AVAILABLE: + monkey_patch_quant_configs() __all__ = [ diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py index ce526cd6a..a5f15c92b 100644 --- a/python/sglang/srt/layers/quantization/blockwise_int8.py +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional import torch from torch.nn import Module -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( @@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import ( QuantizeMethodBase, ) from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear +from sglang.srt.layers.quantization.utils import is_layer_skipped from sglang.srt.utils import set_weight_attrs ACTIVATION_SCHEMES = ["static", "dynamic"] diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py index bff0fe96e..2934fc5ac 100644 --- a/python/sglang/srt/layers/quantization/fp8.py +++ b/python/sglang/srt/layers/quantization/fp8.py @@ -7,20 +7,33 @@ import torch import torch.nn.functional as F from torch.nn import Module from torch.nn.parameter import Parameter -from vllm import _custom_ops as ops -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( - apply_fp8_marlin_linear, - prepare_fp8_layer_for_marlin, -) -from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped -from 
vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.utils import ( all_close_1d, convert_to_channelwise, + is_layer_skipped, per_tensor_dequantize, requantize_with_max_scale, ) +try: + from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( + apply_fp8_marlin_linear, + prepare_fp8_layer_for_marlin, + ) + + MARLIN_FP8_AVAILABLE = True +except ImportError: + MARLIN_FP8_AVAILABLE = False + + def apply_fp8_marlin_linear(*args, **kwargs): + raise ImportError("vllm is not installed") + + def prepare_fp8_layer_for_marlin(*args, **kwargs): + raise ImportError("vllm is not installed") + + from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.linear import ( LinearBase, @@ -46,6 +59,7 @@ from sglang.srt.layers.quantization.fp8_utils import ( ) from sglang.srt.utils import ( get_bool_env_var, + is_cuda, is_hip, permute_weight, print_warning_once, @@ -60,6 +74,13 @@ if _is_hip: from aiter.fused_moe_bf16_asm import asm_moe from aiter.ops.shuffle import shuffle_weight +_is_cuda = is_cuda() + +if _is_cuda: + from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant +else: + from vllm import _custom_ops as vllm_ops + logger = logging.getLogger(__name__) @@ -173,7 +194,9 @@ class Fp8LinearMethod(LinearMethodBase): # For GPUs that lack FP8 hardware support, we can leverage the Marlin # kernel for fast weight-only FP8 quantization - self.use_marlin = get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") + self.use_marlin = ( + get_bool_env_var("SGLANG_FORCE_FP8_MARLIN") and MARLIN_FP8_AVAILABLE + ) # Disable marlin for ROCm if _is_hip: self.use_marlin = False @@ -371,9 +394,12 @@ class Fp8LinearMethod(LinearMethodBase): ) if self.use_marlin: - prepare_fp8_layer_for_marlin(layer) - # Activations not quantized for marlin. - del layer.input_scale + try: + prepare_fp8_layer_for_marlin(layer) + # Activations not quantized for marlin. 
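+                # apply_fp8_marlin_linear() below only consumes the weight and
+                # weight_scale, so the activation scale created for the w8a8
+                # path can be dropped here.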
+ del layer.input_scale + except ImportError: + self.use_marlin = False def apply( self, @@ -383,15 +409,18 @@ class Fp8LinearMethod(LinearMethodBase): ) -> torch.Tensor: if self.use_marlin: - return apply_fp8_marlin_linear( - input=x, - weight=layer.weight, - weight_scale=layer.weight_scale, - workspace=layer.workspace, - size_n=layer.output_size_per_partition, - size_k=layer.input_size_per_partition, - bias=bias, - ) + try: + return apply_fp8_marlin_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + workspace=layer.workspace, + size_n=layer.output_size_per_partition, + size_k=layer.input_size_per_partition, + bias=bias, + ) + except ImportError: + self.use_marlin = False if self.block_quant: return apply_w8a8_block_fp8_linear( @@ -680,12 +709,20 @@ class Fp8MoEMethod: requires_grad=False, ) for expert in range(layer.num_experts): - w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) - ) - w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( - ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) - ) + if _is_cuda: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) + else: + w13_weight[expert, :, :], layer.w13_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :]) + ) + w2_weight[expert, :, :], layer.w2_weight_scale[expert] = ( + vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :]) + ) layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False) diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py index 42c8c1371..bc3813e48 100644 --- a/python/sglang/srt/layers/quantization/fp8_utils.py +++ b/python/sglang/srt/layers/quantization/fp8_utils.py @@ -28,7 +28,12 @@ if _is_cuda: from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_quant_fp8 if use_vllm_cutlass_w8a8_fp8_kernel: - from vllm import _custom_ops as ops + try: + from vllm import _custom_ops as ops + + VLLM_AVAILABLE = True + except ImportError: + VLLM_AVAILABLE = False else: from sgl_kernel import fp8_scaled_mm @@ -219,90 +224,97 @@ def apply_fp8_linear( ) if cutlass_fp8_supported: - if use_vllm_cutlass_w8a8_fp8_kernel: - # Fall back to vllm cutlass w8a8 fp8 kernel - output = ops.cutlass_scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) - else: - assert ( - weight_scale.numel() == weight.shape[1] - ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" - output = fp8_scaled_mm( - qinput, weight, x_scale, weight_scale, out_dtype=input.dtype, bias=bias - ) - return output.view(*output_shape) + try: + if VLLM_AVAILABLE and use_vllm_cutlass_w8a8_fp8_kernel: + # Fall back to vllm cutlass w8a8 fp8 kernel + output = ops.cutlass_scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + else: + assert ( + weight_scale.numel() == weight.shape[1] + ), "cutlass w8a8 fp8 sgl-kernel only supports per-channel scale" + output = fp8_scaled_mm( + qinput, + weight, + x_scale, + weight_scale, + out_dtype=input.dtype, + bias=bias, + ) + return output.view(*output_shape) + except (ImportError, NameError, 
AttributeError): + pass # torch.scaled_mm supports per tensor weights + activations only # so fallback to naive if per channel or per token + per_tensor_weights = weight_scale.numel() == 1 + per_tensor_activations = x_scale.numel() == 1 + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm( + qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + else: - per_tensor_weights = weight_scale.numel() == 1 - per_tensor_activations = x_scale.numel() == 1 + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm - if per_tensor_weights and per_tensor_activations: - # Fused GEMM_DQ - output = torch._scaled_mm( - qinput, - weight, - out_dtype=input.dtype, - scale_a=x_scale, - scale_b=weight_scale, - bias=bias, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. - return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) - else: - # Fallback for channelwise case, where we use unfused DQ - # due to limitations with scaled_mm + # GEMM + # This computes C = (X * W). + # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm( + qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32, + ) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - # Symmetric quantized GEMM by definition computes the following: - # C = (s_x * X) (s_w * W) + bias - # This is equivalent to dequantizing the weights and activations - # before applying a GEMM. - # - # In order to compute quantized operands, a quantized kernel - # will rewrite the above like so: - # C = s_w * s_x * (X * W) + bias - # - # For the scaled_mm fallback case, we break this down, since it - # does not support s_w being a vector. - - # Making sure the dummy tensor is on the same device as the weight - global TORCH_DEVICE_IDENTITY - if TORCH_DEVICE_IDENTITY.device != weight.device: - TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) - - # GEMM - # This computes C = (X * W). 
- # Output in fp32 to allow subsequent ops to happen in-place - output = torch._scaled_mm( - qinput, - weight, - scale_a=TORCH_DEVICE_IDENTITY, - scale_b=TORCH_DEVICE_IDENTITY, - out_dtype=torch.float32, - ) - # A fix for discrepancy in scaled_mm which returns tuple - # for torch < 2.5 and a single value in torch >= 2.5 - if type(output) is tuple and len(output) == 2: - output = output[0] - # Unpad (undo num_token_padding) - output = torch.narrow(output, 0, 0, input_2d.shape[0]) - x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) - - # DQ - # C = sw * sx * (X * W) + bias - output = output * x_scale * weight_scale.t() - if bias is not None: - output = output + bias - return output.to(dtype=input.dtype).view(*output_shape) + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) diff --git a/python/sglang/srt/layers/quantization/gptq.py b/python/sglang/srt/layers/quantization/gptq.py index b15498864..aecdd4cee 100644 --- a/python/sglang/srt/layers/quantization/gptq.py +++ b/python/sglang/srt/layers/quantization/gptq.py @@ -3,11 +3,21 @@ from fractions import Fraction from typing import Any, Dict, List, Optional, Union import torch -from vllm.scalar_type import scalar_types from sglang.srt.layers.linear import LinearBase from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.utils import scalar_types from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead +from sglang.srt.utils import is_cuda + +_is_cuda = is_cuda() + +try: + import vllm + + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False logger = logging.getLogger(__name__) @@ -110,6 +120,9 @@ class GPTQConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["GPTQLinearMethod"]: + if not VLLM_AVAILABLE: + raise ImportError("vllm is not installed") + from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod from sglang.srt.layers.quantization import get_linear_quant_method @@ -263,6 +276,9 @@ class GPTQMarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["QuantizeMethodBase"]: + if not VLLM_AVAILABLE: + raise ImportError("vllm is not installed") + from vllm.model_executor.layers.quantization.gptq_marlin import ( GPTQMarlinLinearMethod, GPTQMarlinMoEMethod, @@ -285,6 +301,9 @@ class GPTQMarlinConfig(QuantizationConfig): @classmethod def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]): + if not VLLM_AVAILABLE: + return False + quant_method = quant_config.get("quant_method", "").lower() num_bits = quant_config.get("bits") group_size = quant_config.get("group_size") @@ -294,9 +313,8 @@ class GPTQMarlinConfig(QuantizationConfig): from vllm.model_executor.layers.quantization.utils.marlin_utils import ( check_marlin_supported, ) - from vllm.platforms import current_platform - if not current_platform.is_cuda(): + if not _is_cuda: return False if quant_method != "gptq": @@ -407,6 +425,9 @@ class MarlinConfig(QuantizationConfig): def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> Optional["MarlinLinearMethod"]: + if not VLLM_AVAILABLE: + raise ImportError("vllm is not installed") + from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod if isinstance(layer, LinearBase) or ( diff --git a/python/sglang/srt/layers/quantization/kv_cache.py 
b/python/sglang/srt/layers/quantization/kv_cache.py new file mode 100644 index 000000000..4275bca52 --- /dev/null +++ b/python/sglang/srt/layers/quantization/kv_cache.py @@ -0,0 +1,98 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/kv_cache.py + +import logging + +import torch + +from sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.utils import is_hip + +_is_hip = is_hip() + +logger = logging.getLogger(__name__) + + +class BaseKVCacheMethod(QuantizeMethodBase): + """ + Quant method that adds `_k_scale` and `_v_scale` attributes to the + Attention layer to support loading those scaling factors from checkpoints. + The k/v_scale will be used to: + - quantize k/v_cache entries before saving them to the cache + - dequantize k/v_cache entries before fetching them from the cache + + :param quant_config: the appropriate QuantizationConfig + """ + + def __init__(self, quant_config: QuantizationConfig): + self.quant_config = quant_config + + def create_weights(self, layer: torch.nn.Module): + """ + Create "weight" (aka k_scale and v_scale) for an attention layer. + """ + # Initialize the KV cache scales to -1.0, which is an invalid value. + # If the k/v_scale appears in the checkpoint, it will be + # overwritten when loading weights. + layer.k_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + layer.v_scale = torch.nn.Parameter(torch.tensor(-1.0), requires_grad=False) + + @classmethod + def is_fp8_fnuz(cls) -> bool: + # only device 0 is checked, this assumes MI300 platforms are homogeneous + return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName + + def apply(self, layer: torch.nn.Module) -> torch.Tensor: + raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # If the kv-cache dtype is auto, we enforce the k/v_scale to be 1.0 + # regardless whether the kv-scale is available in the checkpoint. + # No need to process kv scales after loading if we are going to + # calculate them on the fly. 
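+        # create_weights() seeded k_scale/v_scale with the sentinel -1.0, so
+        # a positive value below means a scale was loaded from the checkpoint
+        # and a negative value means no scale was found.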
+ if layer.kv_cache_dtype != "auto" and not layer.calculate_kv_scales: + if layer.k_scale > 0.0 and layer.v_scale > 0.0: + # We prefer to use separate k_scale and v_scale if present + k_scale = layer.k_scale.to("cpu").tolist() + v_scale = layer.v_scale.to("cpu").tolist() + if _is_hip and self.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + elif layer.k_scale < 0.0 and layer.v_scale < 0.0: + # If no scales were loaded (both scales are invalid negative + # values), use the default value of 1.0 + k_scale = 1.0 + v_scale = 1.0 + else: + # If we find a single kv_scale in the checkpoint, we remap + # kv_scale to k_scale during weight loading, and duplicate + # k_scale to v_scale here + assert layer.k_scale > 0.0 + scale_to_duplicate = max(layer.k_scale, layer.v_scale) + k_scale = scale_to_duplicate.to("cpu").tolist() + v_scale = scale_to_duplicate.to("cpu").tolist() + if _is_hip and self.is_fp8_fnuz(): + k_scale *= 2 + v_scale *= 2 + + if not isinstance(k_scale, float) or not isinstance(v_scale, float): + raise ValueError( + "Only support per-tensor scaling factor " "for fp8 KV cache" + ) + + # These are used in the final Attention.forward() + layer._k_scale.copy_(k_scale) + layer._v_scale.copy_(v_scale) + layer._k_scale_float = k_scale + layer._v_scale_float = v_scale + if k_scale == 1.0 and v_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype: + logger.warning( + "Using KV cache scaling factor 1.0 for fp8_e4m3. This " + "may cause accuracy issues. Please make sure k/v_scale " + "scaling factors are available in the fp8 checkpoint." + ) + + del layer.k_scale + del layer.v_scale diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py index c26012da2..9961054d8 100644 --- a/python/sglang/srt/layers/quantization/modelopt_quant.py +++ b/python/sglang/srt/layers/quantization/modelopt_quant.py @@ -5,12 +5,6 @@ from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - convert_to_channelwise, - cutlass_fp8_supported, - requantize_with_max_scale, -) from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.linear import LinearBase, LinearMethodBase @@ -19,7 +13,15 @@ from sglang.srt.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) -from sglang.srt.layers.quantization.fp8_utils import apply_fp8_linear +from sglang.srt.layers.quantization.fp8_utils import ( + apply_fp8_linear, + cutlass_fp8_supported, +) +from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod +from sglang.srt.layers.quantization.utils import ( + convert_to_channelwise, + requantize_with_max_scale, +) # Initialize logger for the module logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/layers/quantization/utils.py b/python/sglang/srt/layers/quantization/utils.py new file mode 100644 index 000000000..abe49e80f --- /dev/null +++ b/python/sglang/srt/layers/quantization/utils.py @@ -0,0 +1,442 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py +# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/scalar_type.py + +import functools +import struct +from dataclasses import dataclass +from enum import Enum +from types import MappingProxyType +from typing import List, Mapping, Optional, Tuple, 
Union
+
+import torch
+
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+
+def is_layer_skipped(
+    prefix: str,
+    ignored_layers: List[str],
+    fused_mapping: Mapping[str, List[str]] = MappingProxyType({}),
+) -> bool:
+    # prefix: model.layers.0.self_attn.q_proj
+    # proj_name: q_proj
+    proj_name = prefix.split(".")[-1]
+
+    # Fused layers like gate_up_proj or qkv_proj will not be fused
+    # in the safetensors checkpoint. So, we convert the name
+    # from the fused version to unfused + check to make sure that
+    # each shard of the fused layer has the same scheme.
+    if proj_name in fused_mapping:
+        shard_prefixes = [
+            prefix.replace(proj_name, shard_proj_name)
+            for shard_proj_name in fused_mapping[proj_name]
+        ]
+
+        is_skipped = None
+        for shard_prefix in shard_prefixes:
+            is_shard_skipped = shard_prefix in ignored_layers
+
+            if is_skipped is None:
+                is_skipped = is_shard_skipped
+            elif is_shard_skipped != is_skipped:
+                raise ValueError(
+                    f"Detected some but not all shards of {prefix} "
+                    "are quantized. All shards of fused layers "
+                    "must have the same precision."
+                )
+    else:
+        is_skipped = prefix in ignored_layers
+
+    assert is_skipped is not None
+    return is_skipped
+
+
+def per_tensor_dequantize(
+    tensor: torch.Tensor, inv_scale: Union[float, torch.Tensor]
+) -> torch.Tensor:
+    fake_qweight = tensor.to(torch.float16)
+    dq_weight = fake_qweight * inv_scale
+    return dq_weight
+
+
+def all_close_1d(x: torch.Tensor) -> bool:
+    assert len(x.shape) == 1
+    return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))
+
+
+def convert_to_channelwise(
+    weight_scale: torch.Tensor, logical_widths: List[int]
+) -> torch.Tensor:
+    # Create channelwise buffer
+    weight_scale_channel = torch.empty(
+        (sum(logical_widths), 1), dtype=torch.float32, device=weight_scale.device
+    )
+
+    # Expand each scale to match the size of each logical matrix.
+    start = 0
+    for idx, logical_width in enumerate(logical_widths):
+        end = start + logical_width
+        weight_scale_channel[start:end, :] = weight_scale[idx]
+        start = end
+
+    return weight_scale_channel
+
+
+def requantize_with_max_scale(
+    weight: torch.Tensor, weight_scale: torch.Tensor, logical_widths: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # Max scale to be used for requantization.
+    max_w_scale = weight_scale.max()
+
+    # QKV / MLP is fused in the on-disk checkpoint if any of the
+    # weight scales are still set to the default since we initialize
+    # N weight scales for N shards but we only load 1 weight scale
+    # from disk in this case. Skip requantization in this case (since
+    # we are already quantized with the single scale).
+    # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8
+    unfused_module_in_checkpoint = (
+        weight_scale[-1] > torch.finfo(torch.float8_e4m3fn).min
+    )
+
+    # If unfused checkpoint, we need to requantize with the single scale.
+    if unfused_module_in_checkpoint:
+        start = 0
+        for idx, logical_width in enumerate(logical_widths):
+            end = start + logical_width
+            weight_dq = per_tensor_dequantize(weight[start:end, :], weight_scale[idx])
+            # Dispatch to the sgl kernel on CUDA, mirroring fp8.py above;
+            # otherwise fall back to vllm's op.
+            if _is_cuda:
+                weight[start:end, :], _ = sgl_scaled_fp8_quant(weight_dq, max_w_scale)
+            else:
+                weight[start:end, :], _ = vllm_ops.scaled_fp8_quant(
+                    weight_dq, max_w_scale
+                )
+            start = end
+
+    return max_w_scale, weight
+
+
+# Mirrors enum in `core/scalar_type.hpp`
+class NanRepr(Enum):
+    NONE = 0  # nans are not supported
+    IEEE_754 = 1  # nans are: Exp all 1s, mantissa not all 0s
+    EXTD_RANGE_MAX_MIN = 2  # nans are: Exp all 1s, mantissa all 1s
+
+
+# This ScalarType class is a parallel implementation of the C++ ScalarType
+# class found in csrc/core/scalar_type.hpp. 
These two classes should be kept
+# in sync until the inductor fully supports custom C++ classes.
+@dataclass(frozen=True)
+class ScalarType:
+    """
+    ScalarType can represent a wide range of floating point and integer
+    types, in particular it can be used to represent sub-byte data types
+    (something that torch.dtype currently does not support). It is also
+    capable of representing types with a bias, i.e.:
+      `stored_value = value + bias`,
+    this is useful for quantized types (e.g. standard GPTQ 4bit uses a bias
+    of 8). The implementation for this class can be found in
+    csrc/core/scalar_type.hpp; these type signatures should be kept in sync
+    with that file.
+    """
+
+    exponent: int
+    """
+    Number of bits in the exponent if this is a floating point type
+    (zero if this is an integer type)
+    """
+
+    mantissa: int
+    """
+    Number of bits in the mantissa if this is a floating point type,
+    or the number of bits representing an integer excluding the sign bit if
+    this is an integer type.
+    """
+
+    signed: bool
+    "If the type is signed (i.e. has a sign bit)"
+
+    bias: int
+    """
+    bias used to encode the values in this scalar type
+    (value = stored_value - bias, default 0) for example if we store the
+    type as an unsigned integer with a bias of 128 then the value 0 will be
+    stored as 128 and -1 will be stored as 127 and 1 will be stored as 129.
+    """
+
+    _finite_values_only: bool = False
+    """
+    Private: if infs are supported, use `has_infs()` instead.
+    """
+
+    nan_repr: NanRepr = NanRepr.IEEE_754
+    """
+    How NaNs are represented in this scalar type; returns a NanRepr value.
+    (not applicable for integer types)
+    """
+
+    def _floating_point_max_int(self) -> int:
+        assert (
+            self.mantissa <= 52 and self.exponent <= 11
+        ), f"Cannot represent max/min as a double for type {self.__str__()}"
+
+        max_mantissa = (1 << self.mantissa) - 1
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN:
+            max_mantissa = max_mantissa - 1
+
+        max_exponent = (1 << self.exponent) - 2
+        if self.nan_repr == NanRepr.EXTD_RANGE_MAX_MIN or self.nan_repr == NanRepr.NONE:
+            assert (
+                self.exponent < 11
+            ), f"Cannot represent max/min as a double for type {self.__str__()}"
+            max_exponent = max_exponent + 1
+
+        # adjust the exponent to match that of a double
+        # for now we assume the exponent bias is the standard 2^(e-1) - 1 (where
+        # e is the exponent bits), there is some precedent for non-standard
+        # biases, example `float8_e4m3b11fnuz` here:
+        # https://github.com/jax-ml/ml_dtypes but to avoid premature over
+        # complication we are just assuming the standard exponent bias until
+        # there is a need to support non-standard biases
+        exponent_bias = (1 << (self.exponent - 1)) - 1
+        exponent_bias_double = (1 << 10) - 1  # double e = 11
+
+        max_exponent_double = max_exponent - exponent_bias + exponent_bias_double
+
+        # shift the mantissa and exponent into the proper positions for an
+        # IEEE double and bitwise-or them together.
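+        # Worked example for float8_e4m3fn (EXTD_RANGE_MAX_MIN NaNs):
+        # max_mantissa = 0b110, max_exponent = 14 + 1 = 15, exponent_bias = 7,
+        # so the assembled double decodes to 1.75 * 2**8 = 448.0 as expected.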
+        return (max_mantissa << (52 - self.mantissa)) | (max_exponent_double << 52)
+
+    def _floating_point_max(self) -> float:
+        double_raw = self._floating_point_max_int()
+        return struct.unpack("!d", struct.pack("!Q", double_raw))[0]
+
+    def _raw_max(self) -> Union[int, float]:
+        if self.is_floating_point():
+            return self._floating_point_max()
+        else:
+            assert (
+                self.size_bits < 64 or (self.size_bits == 64 and self.is_signed())
+            ), "Cannot represent max as an int"
+            return (1 << self.mantissa) - 1
+
+    def _raw_min(self) -> Union[int, float]:
+        if self.is_floating_point():
+            assert (
+                self.is_signed()
+            ), "We currently assume all floating point types are signed"
+            sign_bit_double = 1 << 63
+
+            max_raw = self._floating_point_max_int()
+            min_raw = max_raw | sign_bit_double
+            return struct.unpack("!d", struct.pack("!Q", min_raw))[0]
+        else:
+            assert (
+                not self.is_signed() or self.size_bits <= 64
+            ), "Cannot represent min as an int64_t"
+
+            if self.is_signed():
+                return -(1 << (self.size_bits - 1))
+            else:
+                return 0
+
+    @functools.cached_property
+    def id(self) -> int:
+        """
+        Convert the ScalarType to an int which can be passed to pytorch custom
+        ops. This layout of the int must be kept in sync with the C++
+        ScalarType's from_id method.
+        """
+        val = 0
+        offset = 0
+
+        def or_and_advance(member, bit_width):
+            nonlocal val
+            nonlocal offset
+            bit_mask = (1 << bit_width) - 1
+            val = val | (int(member) & bit_mask) << offset
+            offset = offset + bit_width
+
+        or_and_advance(self.exponent, 8)
+        or_and_advance(self.mantissa, 8)
+        or_and_advance(self.signed, 1)
+        or_and_advance(self.bias, 32)
+        or_and_advance(self._finite_values_only, 1)
+        or_and_advance(self.nan_repr.value, 8)
+
+        assert offset <= 64, f"ScalarType fields too big {offset} to fit into an int64"
+
+        return val
+
+    @property
+    def size_bits(self) -> int:
+        return self.exponent + self.mantissa + int(self.signed)
+
+    def min(self) -> Union[int, float]:
+        """
+        Min representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_min() - self.bias
+
+    def max(self) -> Union[int, float]:
+        """
+        Max representable value for this scalar type.
+        (accounting for bias if there is one)
+        """
+        return self._raw_max() - self.bias
+
+    def is_signed(self) -> bool:
+        """
+        If the type is signed (i.e. 
has a sign bit), same as `signed`
+        added for consistency with:
+        https://pytorch.org/docs/stable/generated/torch.Tensor.is_signed.html
+        """
+        return self.signed
+
+    def is_floating_point(self) -> bool:
+        "If the type is a floating point type"
+        return self.exponent != 0
+
+    def is_integer(self) -> bool:
+        "If the type is an integer type"
+        return self.exponent == 0
+
+    def has_bias(self) -> bool:
+        "If the type has a non-zero bias"
+        return self.bias != 0
+
+    def has_infs(self) -> bool:
+        "If the type is floating point and supports infinity"
+        return not self._finite_values_only
+
+    def has_nans(self) -> bool:
+        "If the type supports NaNs (not applicable for integer types)"
+        # Compare against the enum member, not its .value: nan_repr is stored
+        # as a NanRepr instance, so an enum-to-int comparison is always False.
+        return self.nan_repr != NanRepr.NONE
+
+    def is_ieee_754(self) -> bool:
+        """
+        If the type is a floating point type that follows IEEE 754
+        conventions
+        """
+        return self.nan_repr == NanRepr.IEEE_754 and not self._finite_values_only
+
+    def __str__(self) -> str:
+        """
+        naming generally follows: https://github.com/jax-ml/ml_dtypes
+        for floating point types (leading f) the scheme is:
+        `float_em[flags]`
+        flags:
+          - no-flags: means it follows IEEE 754 conventions
+          - f: means finite values only (no infinities)
+          - n: means nans are supported (non-standard encoding)
+        for integer types the scheme is:
+        `[u]int[b]`
+          - if bias is not present it means it is zero
+        """
+        if self.is_floating_point():
+            ret = (
+                "float"
+                + str(self.size_bits)
+                + "_e"
+                + str(self.exponent)
+                + "m"
+                + str(self.mantissa)
+            )
+
+            if not self.is_ieee_754():
+                if self._finite_values_only:
+                    ret = ret + "f"
+                if self.nan_repr != NanRepr.NONE:
+                    ret = ret + "n"
+
+            return ret
+        else:
+            ret = ("int" if self.is_signed() else "uint") + str(self.size_bits)
+            if self.has_bias():
+                ret = ret + "b" + str(self.bias)
+            return ret
+
+    def __repr__(self) -> str:
+        return "ScalarType." + self.__str__()
+
+    # __len__ needs to be defined (and has to throw TypeError) for pytorch's
+    # opcheck to work.
+    def __len__(self) -> int:
+        raise TypeError
+
+    #
+    # Convenience Constructors
+    #
+
+    @classmethod
+    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+        "Create a signed integer scalar type (size_bits includes sign-bit)."
+        ret = cls(0, size_bits - 1, True, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType":
+        """Create an unsigned integer scalar type."""
+        ret = cls(0, size_bits, False, bias if bias else 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType":
+        """
+        Create a standard floating point type
+        (i.e. follows IEEE 754 conventions).
+        """
+        assert mantissa > 0 and exponent > 0
+        ret = cls(exponent, mantissa, True, 0)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+    @classmethod
+    def float_(
+        cls, exponent: int, mantissa: int, finite_values_only: bool, nan_repr: NanRepr
+    ) -> "ScalarType":
+        """
+        Create a non-standard floating point type
+        (i.e. does not follow IEEE 754 conventions).
+        """
+        assert mantissa > 0 and exponent > 0
+        assert nan_repr != NanRepr.IEEE_754, (
+            "use `float_IEEE754` constructor for floating point types that "
+            "follow IEEE 754 conventions"
+        )
+        ret = cls(exponent, mantissa, True, 0, finite_values_only, nan_repr)
+        ret.id  # noqa B018: make sure the id is cached
+        return ret
+
+
+# naming generally follows: https://github.com/jax-ml/ml_dtypes
+# for floating point types (leading f) the scheme is:
+#   `float_em[flags]`
+#   flags:
+#   - no-flags: means it follows IEEE 754 conventions
+#   - f: means finite values only (no infinities)
+#   - n: means nans are supported (non-standard encoding)
+# for integer types the scheme is:
+#   `[u]int[b]`
+#   - if bias is not present it means it is zero
+
+
+class scalar_types:
+    int4 = ScalarType.int_(4, None)
+    uint4 = ScalarType.uint(4, None)
+    int8 = ScalarType.int_(8, None)
+    uint8 = ScalarType.uint(8, None)
+    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
+    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
+    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
+    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
+
+    # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
+    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
+
+    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
+    # "gptq" types
+    uint2b2 = ScalarType.uint(2, 2)
+    uint3b4 = ScalarType.uint(3, 4)
+    uint4b8 = ScalarType.uint(4, 8)
+    uint8b128 = ScalarType.uint(8, 128)
+
+    # colloquial names
+    bfloat16 = float16_e8m7
+    float16 = float16_e5m10