diff --git a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
index 7151f861f..f93732154 100644
--- a/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
+++ b/benchmark/kernels/deepseek/benchmark_deepgemm_fp8_gemm.py
@@ -10,7 +10,9 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
     w8a8_block_fp8_matmul as vllm_w8a8_block_fp8_matmul,
 )
 
-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_deepgemm as w8a8_block_fp8_matmul,
+)
 
 # Adapted from https://github.com/tile-ai/tilelang/blob/a8cfdce92795cb861c9033573534653ee040b5ed/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py#L1
diff --git a/python/sglang/srt/layers/quantization/fp8.py b/python/sglang/srt/layers/quantization/fp8.py
index 1f961ed68..3c8ddf44c 100644
--- a/python/sglang/srt/layers/quantization/fp8.py
+++ b/python/sglang/srt/layers/quantization/fp8.py
@@ -49,8 +49,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
-    apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
+    dispatch_w8a8_block_fp8_linear,
     input_to_float8,
     is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -209,6 +209,8 @@ class Fp8LinearMethod(LinearMethodBase):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False
 
+        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -417,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
         if self.block_quant:
-            return apply_w8a8_block_fp8_linear(
+            return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 226986612..f92ff801f 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -740,7 +740,59 @@ if _is_hip:
     return _w8a8_block_fp8_matmul
 
 
-def w8a8_block_fp8_matmul(
+def prepare_block_fp8_matmul_inputs(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> Tuple[int, int, int, torch.Tensor]:
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+    assert A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2
+    assert B.is_contiguous()
+    assert Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    return M, N, K, C
+
+
+def w8a8_block_fp8_matmul_deepgemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype,
+) -> torch.Tensor:
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
+
+    # DeepGEMM only supports bfloat16 output
+    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+
+    if supports_custom_op():
+        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+    else:
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+
+    return C
+
+
+def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -764,81 +816,81 @@ def w8a8_block_fp8_matmul(
     Returns:
         torch.Tensor: The result of matmul.
     """
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-
-    assert A.shape[-1] == B.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
-    M = A.numel() // A.shape[-1]
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
 
-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+    block_n, block_k = block_size
 
-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
-
-    # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
-        if supports_custom_op():
-            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
-        else:
-            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
-        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Default config
-            # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": block_size[0],
-                "BLOCK_SIZE_K": block_size[1],
-                "GROUP_SIZE_M": 32,
-                "num_warps": 4,
-                "num_stages": 3,
-            }
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
 
-        def grid(META):
-            return (
-                triton.cdiv(M, META["BLOCK_SIZE_M"])
-                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-            )
-
-        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
-
-        kernel[grid](
-            A,
-            B,
-            C,
-            As,
-            Bs,
-            M,
-            N,
-            K,
-            block_n,
-            block_k,
-            A.stride(-2),
-            A.stride(-1),
-            B.stride(1),
-            B.stride(0),
-            C.stride(-2),
-            C.stride(-1),
-            As.stride(-2),
-            As.stride(-1),
-            Bs.stride(1),
-            Bs.stride(0),
-            **config,
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
         )
 
+    kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
+
+    kernel[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        **config,
+    )
+
     return C
 
 
+# universal entry point, for testing purposes
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+        return w8a8_block_fp8_matmul_deepgemm(
+            A, B, As, Bs, block_size, output_dtype=output_dtype
+        )
+
+    return w8a8_block_fp8_matmul_triton(
+        A, B, As, Bs, block_size, output_dtype=output_dtype
+    )
+
+
 @triton.jit
 def _per_tensor_quant_mla_fp8_stage1(
     x_ptr,
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index f38284389..78816b1d9 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -1,5 +1,5 @@
 import os
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 import torch
 
@@ -21,7 +21,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_block_fp8_matmul_deepgemm,
+    w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -134,7 +135,20 @@ if ENABLE_FLASHINFER_GEMM:
     from flashinfer.gemm import gemm_fp8_nt_groupwise
 
 
-def apply_w8a8_block_fp8_linear(
+def dispatch_w8a8_block_fp8_linear() -> Callable:
+    if ENABLE_FLASHINFER_GEMM:
+        return flashinfer_gemm_w8a8_block_fp8_linear
+    elif CUTLASS_BLOCK_FP8_SUPPORTED:
+        return cutlass_w8a8_block_fp8_linear_with_fallback
+    elif _is_hip and use_aiter_moe:
+        return aiter_w8a8_block_fp8_linear
+    elif _ENABLE_JIT_DEEPGEMM:
+        return deepgemm_w8a8_block_fp8_linear_with_fallback
+    else:
+        return triton_w8a8_block_fp8_linear
+
+
+def flashinfer_gemm_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     block_size: List[int],
@@ -143,58 +157,148 @@ def apply_w8a8_block_fp8_linear(
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert input_scale is None
-    # View input as 2D matrix for fp8 methods
+
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    # TODO: add more robust shape check here
-    shape_supported_by_cutlass = (
-        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+
+    x_scale_input = x_scale.T.contiguous()
+    weight_scale_input = weight_scale.T.contiguous()
+
+    output = gemm_fp8_nt_groupwise(
+        q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
     )
-    if ENABLE_FLASHINFER_GEMM:
-        q_input, x_scale = sglang_per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        x_scale_input = x_scale.T.contiguous()
-        weight_scale_input = weight_scale.T.contiguous()
-        output = gemm_fp8_nt_groupwise(
-            q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
-        )
-    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=True
-        )
-        output = fp8_blockwise_scaled_mm(
-            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
-        )
-    elif _is_hip and use_aiter_moe:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        output = torch.zeros(
-            [q_input.shape[0], weight.shape[0]],
-            dtype=input.dtype,
-            device=q_input.device,
-        )
-        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
-    else:
-        if _ENABLE_JIT_DEEPGEMM:
-            q_input, x_scale = sglang_per_token_group_quant_fp8(
-                input_2d,
-                block_size[1],
-                column_major_scales=True,
-                scale_tma_aligned=True,
-            )
-        else:
-            q_input, x_scale = per_token_group_quant_fp8(
-                input_2d, block_size[1], column_major_scales=False
-            )
-        output = w8a8_block_fp8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
-        )
     if bias is not None:
-        output = output + bias
-    return output.to(dtype=input.dtype).view(*output_shape)
+        output += bias
+
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def cutlass_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    if not shape_supported:
+        # fallback to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=True
+    )
+    output = fp8_blockwise_scaled_mm(
+        q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def deepgemm_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    output_dtype = input.dtype
+    dtype_supported = output_dtype == torch.bfloat16
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    if not (shape_supported and dtype_supported):
+        # fall back to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d,
+        block_size[1],
+        column_major_scales=True,
+        scale_tma_aligned=True,
+    )
+    output = w8a8_block_fp8_matmul_deepgemm(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=output_dtype).view(*output_shape)
+
+
+def aiter_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = torch.zeros(
+        [q_input.shape[0], weight.shape[0]],
+        dtype=input_2d.dtype,
+        device=q_input.device,
+    )
+    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
+
+    if bias is not None:
+        output += bias
+
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def triton_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = w8a8_block_fp8_matmul_triton(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
 
 
 def input_to_float8(
diff --git a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
index 60b2133e4..ed0410298 100644
--- a/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
+++ b/sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py
@@ -9,7 +9,9 @@ from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
-from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+from sglang.srt.layers.quantization.fp8_kernel import (
+    w8a8_block_fp8_matmul_triton as w8a8_block_fp8_matmul,
+)
 
 
 def get_weight_shapes(args):
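
Usage sketch (editorial, not part of the patch): the snippet below illustrates the shape contract that prepare_block_fp8_matmul_inputs asserts and how the universal w8a8_block_fp8_matmul entry point is meant to be called for testing. The problem sizes are hypothetical and the scales are random placeholders rather than real quantization output.

    import torch
    import triton

    from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

    M, N, K = 16, 512, 7168        # illustrative GEMM problem size
    block_n, block_k = 128, 128    # weight-block quantization granularity

    # A is (M, K) activations, B is (N, K) weights, both already cast to FP8.
    A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
    B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)

    # Scale shapes follow the assertions in prepare_block_fp8_matmul_inputs:
    # one scale per token per K-block for A, one scale per (N-block, K-block) for B.
    As = torch.rand(M, triton.cdiv(K, block_k), dtype=torch.float32, device="cuda")
    Bs = torch.rand(
        triton.cdiv(N, block_n), triton.cdiv(K, block_k), dtype=torch.float32, device="cuda"
    )

    # float16 output keeps this call on the Triton kernel; requesting bfloat16 would
    # route to DeepGEMM when _ENABLE_JIT_DEEPGEMM is set, and the DeepGEMM helper in
    # this patch first quantizes activations with sglang_per_token_group_quant_fp8(
    # ..., column_major_scales=True, scale_tma_aligned=True) rather than random scales.
    C = w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k], output_dtype=torch.float16)
    assert C.shape == (M, N)

Resolving the backend once in dispatch_w8a8_block_fp8_linear() (stored on Fp8LinearMethod at construction time) keeps the per-call path free of feature-flag branching; real callers quantize activations with per_token_group_quant_fp8 or sglang_per_token_group_quant_fp8 as shown in the per-backend helpers above.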