# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Adapted from https://github.com/sgl-project/sglang/pull/2575 import functools import json import os from typing import Any, Callable, Optional, Union, List import torch import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_dequantize) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( CUTLASS_BLOCK_FP8_SUPPORTED) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton from vllm.utils import cdiv, direct_register_custom_op, has_deep_gemm logger = init_logger(__name__) def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: if isinstance(x, torch.Tensor): x = x.dtype return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz def cutlass_scaled_mm( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: List[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: return ops.cutlass_scaled_mm(A, B.T, out_dtype=output_dtype, scale_a=As, scale_b=Bs.T) def rocm_aiter_gemm_w8a8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: List[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: import aiter as rocm_aiter return rocm_aiter.gemm_a8w8_blockscale_CK(A, B, As, Bs, dtype=output_dtype) def rocm_aiter_gemm_w8a8_blockscale_fake( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: List[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: m = A.shape[0] n = B.shape[0] Y = torch.empty(m, n, dtype=output_dtype, device=A.device) return Y if current_platform.is_rocm(): direct_register_custom_op( op_name="rocm_aiter_gemm_w8a8_blockscale", op_func=rocm_aiter_gemm_w8a8_blockscale_impl, mutates_args=[], fake_impl=rocm_aiter_gemm_w8a8_blockscale_fake, dispatch_key=current_platform.dispatch_key, ) def dispatch_w8a8_blockscale_func( use_cutlass: bool, use_aiter_and_is_supported: bool ) -> Callable[[ torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int], torch.dtype, ], torch.Tensor]: if use_cutlass: return cutlass_scaled_mm if (use_aiter_and_is_supported): return torch.ops.vllm.rocm_aiter_gemm_w8a8_blockscale return w8a8_block_fp8_matmul def should_use_deepgemm(output_dtype: torch.dtype, weight: torch.Tensor): """ Check if DeepGEMM should be used based on the output dtype and weight shape. DeepGEMM is only supported for bfloat16 output dtype and weights with shape divisible by 128. """ return (current_platform.is_cuda() and current_platform.is_device_capability(90) and has_deep_gemm() and envs.VLLM_USE_DEEP_GEMM and output_dtype == torch.bfloat16 and weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) # TODO fix ROCm->Triton custom path: # https://github.com/vllm-project/vllm/issues/14397 def apply_w8a8_block_fp8_linear( input: torch.Tensor, weight: torch.Tensor, block_size: List[int], weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: assert input_scale is None # View input as 2D matrix for fp8 methods input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] output_dtype = input.dtype if should_use_deepgemm(output_dtype, weight): input_2d = input.view(-1, input.shape[-1]) output_shape = [*input.shape[:-1], weight.shape[0]] q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=True, ) import vllm.model_executor.layers.quantization.deepgemm # noqa: F401 output = torch.ops.vllm.w8a8_block_fp8_matmul_deepgemm( q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype) if bias is not None: output += bias return output.to(dtype=output_dtype).view(*output_shape) if current_platform.is_cuda(): if current_platform.has_device_capability(100): use_cutlass = cutlass_block_fp8_supported and ( cdiv(weight.shape[0], 128) == weight_scale.shape[0] and cdiv(weight.shape[1], 128) == weight_scale.shape[1]) else: # TODO: update this after switching to public sm90 block scale gemm # as it also supports weight.shape % 128 != 0 use_cutlass = cutlass_block_fp8_supported and ( weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0) else: use_cutlass = False w8a8_blockscale_func = dispatch_w8a8_blockscale_func( use_cutlass, use_aiter_and_is_supported) if use_cutlass: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) else: q_input, x_scale = per_token_group_quant_fp8( input_2d, block_size[1], column_major_scales=use_cutlass) output = w8a8_blockscale_func(q_input, weight, x_scale, weight_scale, block_size, input.dtype) if bias is not None: output = output + bias return output.to(dtype=input.dtype).view(*output_shape) def apply_w8a8_block_fp8_linear_fake( input: torch.Tensor, weight: torch.Tensor, block_size: List[int], weight_scale: torch.Tensor, input_scale: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, use_aiter_and_is_supported: bool = False, ) -> torch.Tensor: output_shape = [*input.shape[:-1], weight.shape[0]] return torch.empty(output_shape, dtype=input.dtype, device=input.device) if not current_platform.is_cpu(): direct_register_custom_op( op_name="apply_w8a8_block_fp8_linear", op_func=apply_w8a8_block_fp8_linear, mutates_args=[], fake_impl=apply_w8a8_block_fp8_linear_fake, ) def input_to_float8( x: torch.Tensor, dtype: Optional[torch.dtype] = None ) -> tuple[torch.Tensor, torch.Tensor]: """This function quantizes input values to float8 values " "with tensor-wise quantization.""" dtype = current_platform.fp8_dtype() if dtype is None else dtype finfo = torch.finfo(dtype) min_val, max_val = x.aminmax() amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) scale = finfo.max / amax x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max) return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() def block_quant_to_tensor_quant( x_q_block: torch.Tensor, x_s: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: """This function converts block-wise quantization to tensor-wise quantization. The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale and the block size. The outputs are tensor-wise quantization tensor and tensor-wise quantization scale. Note only float8 is supported for now. """ x_dq_block = scaled_dequantize(x_q_block, x_s) x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) return x_q_tensor, scale @triton.jit def _per_token_group_quant_fp8( # Pointers to inputs and output y_ptr, y_q_ptr, y_s_ptr, group_size, # Num columns of y y_num_columns, y_row_stride, # Avoid to divide zero eps, # Information for float8 fp8_min, fp8_max, # Meta-parameters BLOCK: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. This function converts the tensor values into float8 values. """ groups_per_row = y_num_columns // group_size # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) row = g_id // groups_per_row row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * group_size) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size y_q_ptr += y_q_ptr_offset y_s_ptr += g_id cols = tl.arange(0, BLOCK) # N <= BLOCK mask = cols < group_size y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) @triton.jit def _per_token_group_quant_fp8_colmajor( # Pointers to inputs and output y_ptr, y_q_ptr, y_s_ptr, group_size, # Num columns of y y_num_columns, y_row_stride, # Stride from one column to the next of y_s y_s_col_stride, # Avoid to divide zero eps, # Information for float8 fp8_min, fp8_max, # Meta-parameters BLOCK: tl.constexpr, ): """A Triton-accelerated function to perform per-token-group quantization on a tensor. This function converts the tensor values into float8 values. """ groups_per_row = y_num_columns // group_size # Map the program id to the row of X and Y it should compute. g_id = tl.program_id(0) row = g_id // groups_per_row row_g_id = g_id % groups_per_row # Ensure offset calculations use int64 to prevent overflow y_ptr_offset = (row.to(tl.int64) * y_row_stride) + (row_g_id.to(tl.int64) * group_size) y_ptr += y_ptr_offset y_q_ptr_offset = g_id.to(tl.int64) * group_size y_q_ptr += y_q_ptr_offset # Convert g_id the flattened block coordinate to 2D so we can index # into the output y_scales matrix blocks_per_row = y_num_columns // group_size scale_col = g_id % blocks_per_row scale_row = g_id // blocks_per_row # Ensure offset calculation uses int64 for y_s_ptr y_s_ptr_offset = (scale_col.to(tl.int64) * y_s_col_stride) + scale_row.to( tl.int64) y_s_ptr += y_s_ptr_offset cols = tl.arange(0, BLOCK) # group_size <= BLOCK mask = cols < group_size y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) # Quant _absmax = tl.maximum(tl.max(tl.abs(y)), eps) y_s = _absmax / fp8_max y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) tl.store(y_q_ptr + cols, y_q, mask=mask) tl.store(y_s_ptr, y_s) def per_token_group_quant_fp8( x: torch.Tensor, group_size: int, eps: float = 1e-10, dtype: Optional[torch.dtype] = None, column_major_scales: bool = False, out_q: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Function to perform per-token-group quantization on an input tensor `x`. It converts the tensor values into signed float8 values and returns the quantized tensor along with the scaling factor used for quantization. Args: x: The input tensor with ndim >= 2. group_size: The group size used for quantization. eps: The minimum to avoid dividing zero. dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` is supported for now. column_major_scales: Outputs scales in column major. out_q: Optional output tensor. If not provided, function will create. Returns: tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization. """ dtype = current_platform.fp8_dtype() if dtype is None else dtype assert (x.shape[-1] % group_size == 0), ( f"the last dimension of `x` {x.shape[-1]} must be divisible " f"by `group_size` {group_size}") assert x.stride(-1) == 1, "`x` groups must be contiguous" finfo = torch.finfo(dtype) fp8_min = finfo.min fp8_max = finfo.max assert out_q is None or out_q.shape == x.shape x_q = out_q if x_q is None: x_q = torch.empty_like(x, device=x.device, dtype=dtype) M = x.numel() // group_size N = group_size if column_major_scales: shape = (x.shape[-1] // group_size, ) + x.shape[:-1] x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) else: shape = x.shape[:-1] + (x.shape[-1] // group_size, ) x_s = torch.empty(shape, device=x.device, dtype=torch.float32) BLOCK = triton.next_power_of_2(N) # heuristics for number of warps num_warps = min(max(BLOCK // 256, 1), 8) num_stages = 1 if column_major_scales: _per_token_group_quant_fp8_colmajor[(M, )]( x, x_q, x_s, group_size, x.shape[1], x.stride(0), x_s.stride(1), eps, fp8_min=fp8_min, fp8_max=fp8_max, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, ) else: _per_token_group_quant_fp8[(M, )]( x, x_q, x_s, group_size, x.shape[1], x.stride(0), eps, fp8_min=fp8_min, fp8_max=fp8_max, BLOCK=BLOCK, num_warps=num_warps, num_stages=num_stages, ) return x_q, x_s @triton.jit def _w8a8_block_fp8_matmul( # Pointers to inputs and output A, B, C, As, Bs, # Shape for matmul M, N, K, # Block size for block-wise quantization group_n, group_k, # Stride for inputs and output stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_As_m, stride_As_k, stride_Bs_k, stride_Bs_n, # Meta-parameters BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, ): """Triton-accelerated function used to perform linear operations (dot product) on input tensors `A` and `B` with block-wise quantization, and store the result in output tensor `C`. """ pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_in_group = GROUP_SIZE_M * num_pid_n group_id = pid // num_pid_in_group first_pid_m = group_id * GROUP_SIZE_M group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) pid_m = first_pid_m + (pid % group_size_m) pid_n = (pid % num_pid_in_group) // group_size_m offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) As_ptrs = As + offs_am * stride_As_m offs_bsn = offs_bn // group_n Bs_ptrs = Bs + offs_bsn * stride_Bs_n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) k_start = k * BLOCK_SIZE_K offs_ks = k_start // group_k a_s = tl.load(As_ptrs + offs_ks * stride_As_k) b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] a_ptrs += BLOCK_SIZE_K * stride_ak b_ptrs += BLOCK_SIZE_K * stride_bk if C.dtype.element_ty == tl.bfloat16: c = accumulator.to(tl.bfloat16) elif C.dtype.element_ty == tl.float16: c = accumulator.to(tl.float16) else: c = accumulator.to(tl.float32) offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) tl.store(c_ptrs, c, mask=c_mask) @functools.lru_cache def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, block_k: int) -> Optional[dict[int, Any]]: """ Return optimized configurations for the w8a8 block fp8 kernel. The return value will be a dictionary that maps an irregular grid of batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the kernel on a given batch size bs, the closest batch size in the grid should be picked and the associated configuration chosen to invoke the kernel. """ # First look up if an optimized configuration is available in the configs # directory device_name = current_platform.get_device_name().replace(" ", "_") json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 config_file_path = os.path.join( os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) if os.path.exists(config_file_path): with open(config_file_path) as f: logger.info( "Using configuration from %s for W8A8 Block FP8 kernel.", config_file_path, ) # If a configuration has been found, return it return {int(key): val for key, val in json.load(f).items()} # If no optimized configuration is available, we will use the default # configuration logger.warning( "Using default W8A8 Block FP8 kernel config. Performance might " "be sub-optimal! Config file not found at %s", config_file_path, ) return None def w8a8_block_fp8_matmul( A: torch.Tensor, B: torch.Tensor, As: torch.Tensor, Bs: torch.Tensor, block_size: list[int], output_dtype: torch.dtype = torch.float16, ) -> torch.Tensor: """This function performs matrix multiplication with block-wise quantization. It takes two input tensors `A` and `B` with scales `As` and `Bs`. The output is returned in the specified `output_dtype`. Args: A: The input tensor, e.g., activation. B: The input tensor, e.g., weight. As: The per-token-group quantization scale for `A`. Bs: The per-block quantization scale for `B`. block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128]. output_dytpe: The dtype of the returned tensor. Returns: torch.Tensor: The result of matmul. """ assert len(block_size) == 2 block_n, block_k = block_size[0], block_size[1] assert A.shape[-1] == B.shape[-1] assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] M = A.numel() // A.shape[-1] assert B.ndim == 2 and Bs.ndim == 2 N, K = B.shape assert triton.cdiv(N, block_n) == Bs.shape[0] assert triton.cdiv(K, block_k) == Bs.shape[1] C_shape = A.shape[:-1] + (N, ) C = A.new_empty(C_shape, dtype=output_dtype) configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) if configs: # Get the optimal config if there is one config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Default config # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] # BLOCK_SIZE_K must be divisible by block_size[1] config = { "BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": block_size[0], "BLOCK_SIZE_K": block_size[1], "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2, } def grid(META): return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), ) _w8a8_block_fp8_matmul[grid]( A, B, C, As, Bs, M, N, K, block_n, block_k, A.stride(-2), A.stride(-1), B.stride(1), B.stride(0), C.stride(-2), C.stride(-1), As.stride(-2), As.stride(-1), Bs.stride(1), Bs.stride(0), **config, ) return C