diff --git a/python/pyproject.toml b/python/pyproject.toml
index 296440016..34ada35dc 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -25,7 +25,7 @@ runtime_common = [
 ]
 srt = [
     "sglang[runtime_common]", "cuda-python",
-    "sgl-kernel>=0.0.3.post4", "torch", "vllm>=0.6.4.post1,<=0.7.2",
+    "sgl-kernel>=0.0.3.post5", "torch", "vllm>=0.6.4.post1,<=0.7.2",
     "flashinfer_python>=0.2.0.post2", "outlines>=0.0.44,<=0.1.11"
 ]
 
diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 8ff18715f..47f310a24 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -76,11 +76,60 @@ def _per_token_group_quant_fp8(
     tl.store(y_s_ptr, y_s)
 
 
+@triton.jit
+def _per_token_group_quant_fp8_colmajor(
+    # Pointers to inputs and output
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    group_size,
+    # Num columns of y
+    y_num_columns,
+    # Stride from one column of y_s to the next
+    y_s_col_stride,
+    # Avoid division by zero
+    eps,
+    # Information for float8
+    fp8_min,
+    fp8_max,
+    # Meta-parameters
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor, storing the scales in column-major order.
+    This function converts the tensor values into float8 values.
+    """
+    # Map the program id to the group of y it should compute.
+    g_id = tl.program_id(0)
+    y_ptr += g_id * group_size
+    y_q_ptr += g_id * group_size
+
+    # Convert g_id, the flattened block coordinate, to 2D so we can index
+    # into the output y_scales matrix
+    blocks_per_row = y_num_columns // group_size
+    scale_col = g_id % blocks_per_row
+    scale_row = g_id // blocks_per_row
+    y_s_ptr += scale_col * y_s_col_stride + scale_row
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+    # Quant
+    _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+    y_s = _absmax / fp8_max
+    y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+    tl.store(y_q_ptr + cols, y_q, mask=mask)
+    tl.store(y_s_ptr, y_s)
+
+
 def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
     dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform per-token-group quantization on an input tensor `x`.
 
@@ -112,29 +161,52 @@ def per_token_group_quant_fp8(
     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
     M = x.numel() // group_size
     N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    if column_major_scales:
+        x_s = torch.empty(
+            (x.shape[-1] // group_size,) + x.shape[:-1],
+            device=x.device,
+            dtype=torch.float32,
+        ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     BLOCK = triton.next_power_of_2(N)
     # heuristics for number of warps
     num_warps = min(max(BLOCK // 256, 1), 8)
     num_stages = 1
-    _per_token_group_quant_fp8[(M,)](
-        x,
-        x_q,
-        x_s,
-        group_size,
-        N,
-        eps,
-        fp8_min=fp8_min,
-        fp8_max=fp8_max,
-        BLOCK=BLOCK,
-        num_warps=num_warps,
-        num_stages=num_stages,
-    )
+    if column_major_scales:
+        _per_token_group_quant_fp8_colmajor[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            x.shape[1],
+            x_s.stride(1),
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
+    else:
+        _per_token_group_quant_fp8[(M,)](
+            x,
+            x_q,
+            x_s,
+            group_size,
+            N,
+            eps,
+            fp8_min=fp8_min,
+            fp8_max=fp8_max,
+            BLOCK=BLOCK,
+            num_warps=num_warps,
+            num_stages=num_stages,
+        )
 
     return x_q, x_s
 
diff --git a/python/sglang/srt/layers/quantization/fp8_utils.py b/python/sglang/srt/layers/quantization/fp8_utils.py
index d6ff12ee1..a613f8a38 100644
--- a/python/sglang/srt/layers/quantization/fp8_utils.py
+++ b/python/sglang/srt/layers/quantization/fp8_utils.py
@@ -10,6 +10,9 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 from sglang.srt.utils import is_hip
 
 is_hip_ = is_hip()
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+if _is_cuda:
+    from sgl_kernel import fp8_blockwise_scaled_mm
 
 
 def normalize_e4m3fn_to_e4m3fnuz(
@@ -36,6 +39,19 @@ def normalize_e4m3fn_to_e4m3fnuz(
     return weight, weight_scale, input_scale
 
 
+def cutlass_block_fp8_supported() -> bool:
+    if _is_cuda:
+        major, minor = torch.cuda.get_device_capability()
+        sm_version = major * 10 + minor
+        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
+        if cuda_version >= (12, 0) and sm_version >= 90:
+            return True
+    return False
+
+
+CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
+
+
 def apply_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
@@ -48,11 +64,24 @@ def apply_w8a8_block_fp8_linear(
     # View input as 2D matrix for fp8 methods
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-
-    q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
-    output = w8a8_block_fp8_matmul(
-        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+    # TODO: add more robust shape check here
+    shape_supported_by_cutlass = (
+        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
     )
+    if CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=True
+        )
+        output = fp8_blockwise_scaled_mm(
+            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
+        )
+    else:
+        q_input, x_scale = per_token_group_quant_fp8(
+            input_2d, block_size[1], column_major_scales=False
+        )
+        output = w8a8_block_fp8_matmul(
+            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+        )
 
     if bias is not None:
         output = output + bias
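
For context on the new `column_major_scales` path: the CUTLASS branch consumes the activation scales in column-major order, which is why `per_token_group_quant_fp8` allocates the transposed shape and then permutes it. Below is a minimal CPU-only sketch of that layout and of the colmajor kernel's scale-index math; the sizes (6 tokens, hidden size 512, group width 128) are illustrative assumptions, not values from the diff.

```python
import torch

# Illustrative sizes (assumptions): 6 tokens, hidden size 512,
# quantization group width 128 -> 4 scale groups per token.
M, K, group_size = 6, 512, 128
blocks_per_row = K // group_size

# Scale allocation as in per_token_group_quant_fp8 with
# column_major_scales=True: allocate the transposed shape contiguously,
# then permute, so the token dimension becomes the contiguous one.
x_s = torch.empty((blocks_per_row, M), dtype=torch.float32).permute(-1, -2)
assert x_s.shape == (M, blocks_per_row)
assert x_s.stride() == (1, M)  # column-major: one column spans M elements

# Index math from _per_token_group_quant_fp8_colmajor: program g_id
# handles one group and writes its scale at
# scale_col * y_s_col_stride + scale_row in the flat scale buffer,
# which is exactly element (scale_row, scale_col) of the permuted x_s.
for g_id in range(M * blocks_per_row):
    scale_col = g_id % blocks_per_row
    scale_row = g_id // blocks_per_row
    flat = scale_col * x_s.stride(1) + scale_row
    assert flat == scale_row * x_s.stride(0) + scale_col * x_s.stride(1)
```

Because `x_s.stride(0) == 1` after the permute, the scales of consecutive tokens within the same group column land contiguously in memory; the row-major layout (`column_major_scales=False`) is kept unchanged for the Triton `w8a8_block_fp8_matmul` fallback.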