diff --git a/python/sglang/srt/layers/linear.py b/python/sglang/srt/layers/linear.py index 02557d107..f8f2d9f6d 100644 --- a/python/sglang/srt/layers/linear.py +++ b/python/sglang/srt/layers/linear.py @@ -38,6 +38,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", + "BlockInt8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 77ce7954f..805a43e45 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,7 +15,13 @@ from vllm import _custom_ops as ops from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip +from sglang.srt.layers.quantization.int8_kernel import per_token_group_quant_int8 +from sglang.srt.utils import ( + direct_register_custom_op, + get_device_name, + is_cuda_available, + is_hip, +) is_hip_flag = is_hip() @@ -86,6 +92,7 @@ def fused_moe_kernel( top_k: tl.constexpr, compute_type: tl.constexpr, use_fp8_w8a8: tl.constexpr, + use_int8_w8a8: tl.constexpr, use_int8_w8a16: tl.constexpr, even_Ks: tl.constexpr, ): @@ -159,7 +166,7 @@ def fused_moe_kernel( ) b_scale = tl.load(b_scale_ptrs) - if use_fp8_w8a8: + if use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm offs_bsn = offs_bn // group_n @@ -198,7 +205,7 @@ def fused_moe_kernel( # We accumulate along the K dimension. if use_int8_w8a16: accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) - elif use_fp8_w8a8: + elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k @@ -221,7 +228,7 @@ def fused_moe_kernel( accumulator = accumulator * moe_weight[:, None] if use_int8_w8a16: accumulator = (accumulator * b_scale).to(compute_type) - elif use_fp8_w8a8: + elif use_fp8_w8a8 or use_int8_w8a8: if group_k > 0 and group_n > 0: accumulator = accumulator.to(compute_type) else: @@ -477,6 +484,7 @@ def invoke_fused_moe_kernel( config: Dict[str, Any], compute_type: tl.dtype, use_fp8_w8a8: bool, + use_int8_w8a8: bool, use_int8_w8a16: bool, block_shape: Optional[List[int]] = None, ) -> None: @@ -499,6 +507,18 @@ def invoke_fused_moe_kernel( assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] + elif use_int8_w8a8: + assert B_scale is not None + if block_shape is None: + padded_size = padding_size + A, A_scale = ops.scaled_int8_quant(A, A_scale) + else: + assert len(block_shape) == 2 + block_n, block_k = block_shape[0], block_shape[1] + A, A_scale = per_token_group_quant_int8(A, block_k) + assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1] + assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2] + assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1] elif use_int8_w8a16: assert B_scale is not None else: @@ -548,6 +568,7 @@ def invoke_fused_moe_kernel( top_k=top_k, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, even_Ks=even_Ks, **config, @@ -701,9 +722,12 @@ def get_config_dtype_str( dtype: torch.dtype, use_int8_w8a16: 
Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False, + use_int8_w8a8: Optional[bool] = False, ): if use_fp8_w8a8: return "fp8_w8a8" + elif use_int8_w8a8: + return "int8_w8a8" elif use_int8_w8a16: return "int8_w8a16" elif dtype == torch.float: @@ -721,6 +745,7 @@ def inplace_fused_experts( topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -737,6 +762,7 @@ def inplace_fused_experts( True, activation, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, w1_scale, w2_scale, @@ -754,6 +780,7 @@ def inplace_fused_experts_fake( topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -780,6 +807,7 @@ def outplace_fused_experts( topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -796,6 +824,7 @@ def outplace_fused_experts( False, activation, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, w1_scale, w2_scale, @@ -813,6 +842,7 @@ def outplace_fused_experts_fake( topk_ids: torch.Tensor, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -840,6 +870,7 @@ def fused_experts( inplace: bool = False, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -856,6 +887,7 @@ def fused_experts( topk_ids, activation, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, w1_scale, w2_scale, @@ -873,6 +905,7 @@ def fused_experts( topk_ids, activation, use_fp8_w8a8, + use_int8_w8a8, use_int8_w8a16, w1_scale, w2_scale, @@ -891,6 +924,7 @@ def fused_experts_impl( inplace: bool = False, activation: str = "silu", use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -899,7 +933,7 @@ def fused_experts_impl( block_shape: Optional[List[int]] = None, ): padded_size = padding_size - if not use_fp8_w8a8 or block_shape is not None: + if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None: padded_size = 0 # Check constraints. 
@@ -918,6 +952,7 @@ def fused_experts_impl( M = min(num_tokens, CHUNK_SIZE) config_dtype = get_config_dtype_str( use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, dtype=hidden_states.dtype, ) @@ -1001,6 +1036,7 @@ def fused_experts_impl( config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, block_shape=block_shape, ) @@ -1034,6 +1070,7 @@ def fused_experts_impl( config, compute_type=compute_type, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, block_shape=block_shape, ) @@ -1078,6 +1115,7 @@ def fused_moe( topk_group: Optional[int] = None, custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, + use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, @@ -1105,6 +1143,8 @@ def fused_moe( note: Deepseek V2/V3/R1 series models use grouped_topk - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. + - use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner products for w1 and w2. Defaults to False. - w1_scale (Optional[torch.Tensor]): Optional scale to be used for @@ -1144,6 +1184,7 @@ def fused_moe( inplace=inplace, activation=activation, use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a8=use_int8_w8a8, use_int8_w8a16=use_int8_w8a16, w1_scale=w1_scale, w2_scale=w2_scale, diff --git a/python/sglang/srt/layers/quantization/__init__.py b/python/sglang/srt/layers/quantization/__init__.py index da6f29312..928643b70 100644 --- a/python/sglang/srt/layers/quantization/__init__.py +++ b/python/sglang/srt/layers/quantization/__init__.py @@ -24,6 +24,7 @@ from vllm.model_executor.layers.quantization.qqq import QQQConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config from sglang.srt.layers.quantization.fp8 import Fp8Config from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config @@ -34,6 +35,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "deepspeedfp": DeepSpeedFPConfig, "tpu_int8": Int8TpuConfig, "fp8": Fp8Config, + "blockwise_int8": BlockInt8Config, "fbgemm_fp8": FBGEMMFp8Config, "marlin": MarlinConfig, "modelopt": ModelOptFp8Config, diff --git a/python/sglang/srt/layers/quantization/blockwise_int8.py b/python/sglang/srt/layers/quantization/blockwise_int8.py new file mode 100644 index 000000000..ef03a3610 --- /dev/null +++ b/python/sglang/srt/layers/quantization/blockwise_int8.py @@ -0,0 +1,406 @@ +# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py + +import logging +from typing import Any, Callable, Dict, List, Optional + +import torch +from torch.nn import Module +from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped + +from sglang.srt.distributed import get_tensor_model_parallel_world_size +from sglang.srt.layers.linear import ( + LinearBase, + LinearMethodBase, + UnquantizedLinearMethod, +) +from sglang.srt.layers.parameter import ModelWeightParameter, PerTensorScaleParameter +from 
sglang.srt.layers.quantization.base_config import ( + QuantizationConfig, + QuantizeMethodBase, +) +from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter +from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear +from sglang.srt.utils import set_weight_attrs + +ACTIVATION_SCHEMES = ["static", "dynamic"] + +logger = logging.getLogger(__name__) + + +class BlockInt8Config(QuantizationConfig): + """Config class for INT8.""" + + def __init__( + self, + is_checkpoint_int8_serialized: bool = False, + activation_scheme: str = "dynamic", + ignored_layers: Optional[List[str]] = None, + weight_block_size: List[int] = None, + ) -> None: + self.is_checkpoint_int8_serialized = is_checkpoint_int8_serialized + if is_checkpoint_int8_serialized: + logger.warning( + "Detected int8 checkpoint. Please note that the " + "format is experimental and subject to change." + ) + if activation_scheme not in ACTIVATION_SCHEMES: + raise ValueError(f"Unsupported activation scheme {activation_scheme}") + self.activation_scheme = activation_scheme + self.ignored_layers = ignored_layers or [] + if weight_block_size is not None: + if not is_checkpoint_int8_serialized: + raise ValueError( + f"The block-wise quantization only supports int8-serialized checkpoint for now." + ) + if len(weight_block_size) != 2: + raise ValueError( + f"The quantization block size of weight must have 2 dimensions, but got {len(weight_block_size)} dimensions." + ) + if activation_scheme != "dynamic": + raise ValueError( + f"The block-wise quantization only supports dynamic activation scheme for now, but got {activation_scheme} activation scheme." + ) + self.weight_block_size = weight_block_size + + @classmethod + def get_name(cls) -> str: + return "blockwise_int8" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.bfloat16, torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config": + quant_method = cls.get_from_keys(config, ["quant_method"]) + is_checkpoint_int8_serialized = "int8" in quant_method + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + ignored_layers = cls.get_from_keys_or(config, ["ignored_layers"], None) + weight_block_size = cls.get_from_keys_or(config, ["weight_block_size"], None) + return cls( + is_checkpoint_int8_serialized=is_checkpoint_int8_serialized, + activation_scheme=activation_scheme, + ignored_layers=ignored_layers, + weight_block_size=weight_block_size, + ) + + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional["QuantizeMethodBase"]: + from sglang.srt.layers.moe.fused_moe_triton import FusedMoE + + if isinstance(layer, LinearBase): + if is_layer_skipped(prefix, self.ignored_layers): + return UnquantizedLinearMethod() + return BlockInt8LinearMethod(self) + elif isinstance(layer, FusedMoE): + return BlockInt8MoEMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class BlockInt8LinearMethod(LinearMethodBase): + """Linear method for INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic activation scale. + + Limitations: + Only support block-wise int8 quantization and int8 checkpoint + + Args: + quant_config: The quantization config. 
+ """ + + def __init__(self, quant_config: BlockInt8Config): + self.quant_config = quant_config + assert self.quant_config.weight_block_size is not None + assert self.quant_config.is_checkpoint_int8_serialized + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + weight_loader = extra_weight_attrs.get("weight_loader") + + tp_size = get_tensor_model_parallel_world_size() + + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # Required by row parallel + if tp_size > 1 and input_size // input_size_per_partition == tp_size: + if input_size_per_partition % block_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible by " + f"weight quantization block_k = {block_k}." + ) + # Required by column parallel or enabling merged weights + if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len( + output_partition_sizes + ) > 1: + for output_partition_size in output_partition_sizes: + if output_partition_size % block_n != 0: + raise ValueError( + f"Weight output_partition_size = " + f"{output_partition_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + + layer.logical_widths = output_partition_sizes + + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.orig_dtype = params_dtype + + # WEIGHT + weight_dtype = ( + torch.int8 + if self.quant_config.is_checkpoint_int8_serialized + else params_dtype + ) + + weight = ModelWeightParameter( + data=torch.empty( + output_size_per_partition, input_size_per_partition, dtype=weight_dtype + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + + scale = BlockQuantScaleParameter( + data=torch.empty( + (output_size_per_partition + block_n - 1) // block_n, + (input_size_per_partition + block_k - 1) // block_k, + dtype=torch.float32, + ), + input_dim=1, + output_dim=0, + weight_loader=weight_loader, + ) + scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale_inv", scale) + + # INPUT ACTIVATION SCALE + assert self.quant_config.activation_scheme == "dynamic" + layer.register_parameter("input_scale", None) + + def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + # Use torch Parameter to avoid cuda graph capturing issue + layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False) + layer.weight_scale_inv = torch.nn.Parameter( + layer.weight_scale_inv.data, requires_grad=False + ) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return apply_w8a8_block_int8_linear( + input=x, + weight=layer.weight, + block_size=self.quant_config.weight_block_size, + weight_scale=layer.weight_scale_inv, + input_scale=None, + bias=bias, + ) + + +class BlockInt8MoEMethod: + """MoE method for INT8. + Supports loading INT8 checkpoints with static weight scale and + dynamic activation scale. + + Limitations: + Only support block-wise int8 quantization and int8 checkpoint + + Args: + quant_config: The quantization config.
+ """ + + def __new__(cls, *args, **kwargs): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase + + if not hasattr(cls, "_initialized"): + original_init = cls.__init__ + new_cls = type( + cls.__name__, + (FusedMoEMethodBase,), + { + "__init__": original_init, + **{k: v for k, v in cls.__dict__.items() if k != "__dict__"}, + }, + ) + obj = super(new_cls, new_cls).__new__(new_cls) + obj.__init__(*args, **kwargs) + return obj + return super().__new__(cls) + + def __init__(self, quant_config): + self.quant_config = quant_config + assert self.quant_config.weight_block_size is not None + assert self.quant_config.is_checkpoint_int8_serialized + + def create_weights( + self, + layer: Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported + + if self.quant_config.is_checkpoint_int8_serialized: + params_dtype = torch.int8 + tp_size = get_tensor_model_parallel_world_size() + + block_n, block_k = ( + self.quant_config.weight_block_size[0], + self.quant_config.weight_block_size[1], + ) + # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n. + # Required by column parallel or enabling merged weights + if intermediate_size % block_n != 0: + raise ValueError( + f"The output_size of gate's and up's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_n = {block_n}." + ) + if tp_size > 1: + # Required by row parallel + if intermediate_size % block_k != 0: + raise ValueError( + f"The input_size of down's weight = " + f"{intermediate_size} is not divisible by " + f"weight quantization block_k = {block_k}."
+ ) + + # WEIGHTS + w13_weight = torch.nn.Parameter( + torch.empty( + num_experts, 2 * intermediate_size, hidden_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter( + torch.empty( + num_experts, hidden_size, intermediate_size, dtype=params_dtype + ), + requires_grad=False, + ) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + w13_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + 2 * ((intermediate_size + block_n - 1) // block_n), + (hidden_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + w2_weight_scale = torch.nn.Parameter( + torch.ones( + num_experts, + (hidden_size + block_n - 1) // block_n, + (intermediate_size + block_k - 1) // block_k, + dtype=torch.float32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_weight_scale_inv", w13_weight_scale) + layer.register_parameter("w2_weight_scale_inv", w2_weight_scale) + + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value} + ) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + assert self.quant_config.activation_scheme == "dynamic" + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: Module) -> None: + # Block quant doesn't need to process weights after loading + return + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + correction_bias: Optional[torch.Tensor] = None, + activation: str = "silu", + ) -> torch.Tensor: + from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts + from sglang.srt.layers.moe.topk import select_experts + + # Expert selection + topk_weights, topk_ids = select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + correction_bias=correction_bias, + ) + + # Expert fusion with INT8 quantization + return fused_experts( + x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + use_int8_w8a8=True, + w1_scale=(layer.w13_weight_scale_inv), + w2_scale=(layer.w2_weight_scale_inv), + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale, + block_shape=self.quant_config.weight_block_size, + ) diff --git a/python/sglang/srt/layers/quantization/int8_kernel.py b/python/sglang/srt/layers/quantization/int8_kernel.py index 91b56f9e0..79b03f64a 100644 --- a/python/sglang/srt/layers/quantization/int8_kernel.py +++ b/python/sglang/srt/layers/quantization/int8_kernel.py @@ -1,7 +1,17 @@ +import functools +import json +import logging +import os +from typing import Any, Dict, List, Optional, Tuple + import torch import triton import triton.language as tl +from sglang.srt.utils import get_device_name + +logger = logging.getLogger(__name__) + @triton.jit def _per_token_quant_int8( @@ -52,3 +62,320 @@ def per_token_quant_int8(x): ) return x_q, scales 
+ + +@triton.jit +def _per_token_group_quant_int8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + # Stride of input + y_stride, + # Columns of input + N, + # Avoid dividing by zero + eps, + # Information for int8 + int8_min, + int8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group quantization on a + tensor. + + This function converts the tensor values into int8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * y_stride + y_q_ptr += g_id * y_stride + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < N + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / int8_max + y_q = tl.clamp(y / y_s, int8_min, int8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_int8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: torch.dtype = torch.int8, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + + It converts the tensor values into signed int8 values and returns the + quantized tensor along with the scaling factor used for quantization. + + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum value to avoid dividing by zero. + dtype: The dtype of the output tensor. Note that only `torch.int8` is supported for now. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` must be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_max = iinfo.max + int8_min = iinfo.min + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + x_s = torch.empty( + x.shape[:-1] + (x.shape[-1] // group_size,), + device=x.device, + dtype=torch.float32, + ) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_int8[(M,)]( + x, + x_q, + x_s, + group_size, + N, + eps, + int8_min=int8_min, + int8_max=int8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_int8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and store the result in output + tensor `C`.
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b).to(tl.float32) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_int8_configs( + N: int, K: int, block_n: int, block_k: int +) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block int8 kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block int8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=int8_w8a8,block_shape=[{block_n}, {block_k}].json" + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name + ) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block INT8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + ( + "Using default W8A8 Block INT8 kernel config. Performance might be sub-optimal!
" + "Config file not found at %s" + ), + config_file_path, + ) + return None + + +def w8a8_block_int8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise quantization. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. + output_dtype: The dtype of the returned tensor. + + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N,) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_int8_configs(N, K, block_size[0], block_size[1]) + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 3, + } + + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + + _w8a8_block_int8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C diff --git a/python/sglang/srt/layers/quantization/int8_utils.py b/python/sglang/srt/layers/quantization/int8_utils.py new file mode 100644 index 000000000..9a7aa6d82 --- /dev/null +++ b/python/sglang/srt/layers/quantization/int8_utils.py @@ -0,0 +1,73 @@ +from typing import List, Optional, Tuple + +import torch + +from sglang.srt.layers.quantization.int8_kernel import ( + per_token_group_quant_int8, + w8a8_block_int8_matmul, +) + + +def apply_w8a8_block_int8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for int8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + q_input, x_scale = per_token_group_quant_int8(input_2d, block_size[1]) + output = w8a8_block_int8_matmul( + q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype + ) + + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def input_to_int8( + x: torch.Tensor, dtype: torch.dtype = torch.int8 +) ->
Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to int8 values with tensor-wise quantization.""" + iinfo = torch.iinfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + int8_min, int8_max = iinfo.min, iinfo.max + scale = int8_max / amax + x_scl_sat = (x * scale).clamp(min=int8_min, max=int8_max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_dequant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, + block_size: List[int], +) -> torch.Tensor: + """This function conducts block-wise dequantization. + The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale + and the block size. + The outputs are dequantized tensor. + """ + block_n, block_k = block_size[0], block_size[1] + n, k = x_q_block.shape + n_tiles = (n + block_n - 1) // block_n + k_tiles = (k + block_k - 1) // block_k + assert n_tiles == x_s.shape[0] + assert k_tiles == x_s.shape[1] + + x_dq_block = x_q_block.to(torch.float32) + + for i in range(k_tiles): + for j in range(n_tiles): + x_dq_block[ + j * block_n : min((j + 1) * block_n, n), + i * block_k : min((i + 1) * block_k, k), + ] *= x_s[j][i] + + return x_dq_block diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index afb97ed03..6a6b5f387 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -47,6 +47,9 @@ from sglang.srt.layers.quantization.fp8_utils import ( input_to_float8, normalize_e4m3fn_to_e4m3fnuz, ) +from sglang.srt.layers.quantization.int8_utils import ( + block_dequant as int8_block_dequant, +) from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper from sglang.srt.layers.vocab_parallel_embedding import ( @@ -994,6 +997,18 @@ class DeepseekV2ForCausalLM(nn.Module): weight, weight_scale, weight_block_size ) self_attn.w_scale = scale + if ( + hasattr(self.quant_config, "weight_block_size") + and w.dtype == torch.int8 + ): + weight_block_size = self.quant_config.weight_block_size + if weight_block_size is not None: + assert hasattr(self_attn.kv_b_proj, "weight_scale_inv") + weight = w + weight_scale = self_attn.kv_b_proj.weight_scale_inv + w = int8_block_dequant( + weight, weight_scale, weight_block_size + ).to(torch.bfloat16) w_kc, w_vc = w.unflatten( 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim) ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index d6f3fbf8e..506b87bf6 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -55,6 +55,7 @@ suites = { "test_vision_openai_server.py", "test_w8a8_quantization.py", "test_fp8_kernel.py", + "test_block_int8.py", ], "nightly": [ "test_nightly_gsm8k_eval.py", diff --git a/test/srt/test_block_int8.py b/test/srt/test_block_int8.py new file mode 100644 index 000000000..1c6c7656c --- /dev/null +++ b/test/srt/test_block_int8.py @@ -0,0 +1,221 @@ +import itertools +import unittest + +import torch + +from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe + + +# For test +def native_per_token_group_quant_int8(x, group_size, eps=1e-10, dtype=torch.int8): + """Function to perform per-token-group quantization on an input tensor `x` using native torch. 
+ + It converts the tensor values into signed int8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Note that only `torch.int8` is supported for now. + """ + assert ( + x.shape[-1] % group_size == 0 + ), "the last dimension of `x` must be divisible by `group_size`" + assert x.is_contiguous(), "`x` is not contiguous" + + iinfo = torch.iinfo(dtype) + int8_min = iinfo.min + int8_max = iinfo.max + + x_ = x.reshape(x.numel() // group_size, group_size) + amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32) + x_s = amax / int8_max + x_q = (x_ / x_s).clamp(min=int8_min, max=int8_max).to(dtype) + x_q = x_q.reshape(x.shape) + x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,)) + + return x_q, x_s + + +# For test +def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16): + """This function performs matrix multiplication with block-wise quantization using native torch. + + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + """ + + A = A.to(torch.float32) + B = B.to(torch.float32) + assert A.shape[-1] == B.shape[-1] + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1] + assert A.shape[:-1] == As.shape[:-1] + + M = A.numel() // A.shape[-1] + N, K = B.shape + origin_C_shape = A.shape[:-1] + (N,) + A = A.reshape(M, A.shape[-1]) + As = As.reshape(M, As.shape[-1]) + n_tiles = (N + block_n - 1) // block_n + k_tiles = (K + block_k - 1) // block_k + assert n_tiles == Bs.shape[0] + assert k_tiles == Bs.shape[1] + + C_shape = (M, N) + C = torch.zeros(C_shape, dtype=torch.float32, device=A.device) + + A_tiles = [A[:, i * block_k : min((i + 1) * block_k, K)] for i in range(k_tiles)] + B_tiles = [ + [ + B[ + j * block_n : min((j + 1) * block_n, N), + i * block_k : min((i + 1) * block_k, K), + ] + for i in range(k_tiles) + ] + for j in range(n_tiles) + ] + C_tiles = [C[:, j * block_n : min((j + 1) * block_n, N)] for j in range(n_tiles)] + As_tiles = [As[:, i : i + 1] for i in range(k_tiles)] + + for i in range(k_tiles): + for j in range(n_tiles): + a = A_tiles[i] + b = B_tiles[j][i] + c = C_tiles[j] + s = As_tiles[i] * Bs[j][i] + c[:, :] += torch.matmul(a, b.t()) * s + + C = C.reshape(origin_C_shape).to(output_dtype) + return C + + +# For test +def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape): + """This function performs fused moe with block-wise quantization using native torch.""" + + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + _, block_k = block_shape[0], block_shape[1] + a_q, a_s = native_per_token_group_quant_int8(a, block_k) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + inter_out = native_w8a8_block_int8_matmul( + a_q[mask], w1[i], a_s[mask], w1_s[i], block_shape, output_dtype=a.dtype + ) + act_out = SiluAndMul().forward_native(inter_out) + act_out_q, act_out_s = native_per_token_group_quant_int8(act_out, block_k) + act_out = act_out.to(torch.float32) + out[mask] = native_w8a8_block_int8_matmul( + act_out_q, w2[i], act_out_s, w2_s[i],
block_shape, output_dtype=a.dtype + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +class TestW8A8BlockINT8FusedMoE(unittest.TestCase): + DTYPES = [torch.half, torch.bfloat16] + M = [1, 33, 64, 222] + N = [128, 1024] + K = [256, 4096] + E = [8, 24] + TOP_KS = [2, 6] + # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]] + BLOCK_SIZE = [[128, 128]] + SEEDS = [0] + + @classmethod + def setUpClass(cls): + if not torch.cuda.is_available(): + raise unittest.SkipTest("CUDA is not available") + torch.set_default_device("cuda") + + def _w8a8_block_int8_fused_moe(self, M, N, K, E, topk, block_size, dtype, seed): + torch.manual_seed(seed) + # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half + factor_for_scale = 1e-2 + int8_info = torch.iinfo(torch.int8) + int8_max, int8_min = int8_info.max, int8_info.min + + a = torch.randn((M, K), dtype=dtype) / 10 + + w1_fp32 = (torch.rand((E, 2 * N, K), dtype=torch.float32) - 0.5) * 2 * int8_max + w1 = w1_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + w2_fp32 = (torch.rand((E, K, N), dtype=torch.float32) - 0.5) * 2 * int8_max + w2 = w2_fp32.clamp(min=int8_min, max=int8_max).to(torch.int8) + + block_n, block_k = block_size[0], block_size[1] + n_tiles_w1 = (2 * N + block_n - 1) // block_n + n_tiles_w2 = (K + block_n - 1) // block_n + k_tiles_w1 = (K + block_k - 1) // block_k + k_tiles_w2 = (N + block_k - 1) // block_k + + w1_s = ( + torch.rand((E, n_tiles_w1, k_tiles_w1), dtype=torch.float32) + * factor_for_scale + ) + w2_s = ( + torch.rand((E, n_tiles_w2, k_tiles_w2), dtype=torch.float32) + * factor_for_scale + ) + + score = torch.randn((M, E), dtype=dtype) + + with torch.inference_mode(): + out = fused_moe( + a, + w1, + w2, + score, + topk, + renormalize=False, + use_int8_w8a8=True, + w1_scale=w1_s, + w2_scale=w2_s, + block_shape=block_size, + ) + ref_out = torch_w8a8_block_int8_moe( + a, w1, w2, w1_s, w2_s, score, topk, block_size + ) + + self.assertTrue( + torch.mean(torch.abs(out.to(torch.float32) - ref_out.to(torch.float32))) + / torch.mean(torch.abs(ref_out.to(torch.float32))) + < 0.02 + ) + + def test_w8a8_block_int8_fused_moe(self): + for params in itertools.product( + self.M, + self.N, + self.K, + self.E, + self.TOP_KS, + self.BLOCK_SIZE, + self.DTYPES, + self.SEEDS, + ): + with self.subTest( + M=params[0], + N=params[1], + K=params[2], + E=params[3], + topk=params[4], + block_size=params[5], + dtype=params[6], + seed=params[7], + ): + self._w8a8_block_int8_fused_moe(*params) + + +if __name__ == "__main__": + unittest.main(verbosity=2)
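
Usage sketch (not part of the patch itself): a minimal way to exercise the new block-wise INT8 kernels from int8_kernel.py on their own, mirroring what apply_w8a8_block_int8_linear does internally: per-token-group quantize the activation, then run the block GEMM against an int8 weight with per-block scales. The shapes, the random int8 weight, and the constant scales below are illustrative assumptions; in a real model they come from a blockwise_int8 checkpoint's "weight" and "weight_scale_inv" tensors. A CUDA device is assumed since the kernels are Triton-based.

    import torch

    from sglang.srt.layers.quantization.int8_kernel import (
        per_token_group_quant_int8,
        w8a8_block_int8_matmul,
    )

    torch.set_default_device("cuda")  # the Triton kernels require a GPU
    block_n, block_k = 128, 128  # weight_block_size, e.g. from the checkpoint config
    M, N, K = 16, 512, 1024  # tokens, output features, input features

    x = torch.randn(M, K, dtype=torch.bfloat16)

    # Illustrative int8 weight with one scale per (block_n x block_k) tile; a real
    # checkpoint provides these as the "weight" and "weight_scale_inv" parameters.
    w_q = torch.randint(-127, 128, (N, K), dtype=torch.int8)
    w_s = torch.full((N // block_n, K // block_k), 1.0 / 127, dtype=torch.float32)

    # Dynamic per-token-group quantization of the activation, then the block GEMM.
    x_q, x_s = per_token_group_quant_int8(x, block_k)
    y = w8a8_block_int8_matmul(x_q, w_q, x_s, w_s, [block_n, block_k], output_dtype=x.dtype)
    print(y.shape)  # torch.Size([16, 512])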