refactor apply_w8a8_block_fp8_linear in fp (#6545)
@@ -49,8 +49,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
-    apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
+    dispatch_w8a8_block_fp8_linear,
     input_to_float8,
     is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
@@ -209,6 +209,8 @@ class Fp8LinearMethod(LinearMethodBase):
             # Marlin doesn't support block-wise fp8
             self.use_marlin = False

+        self.w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()
+
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -417,7 +419,7 @@ class Fp8LinearMethod(LinearMethodBase):
             )

         if self.block_quant:
-            return apply_w8a8_block_fp8_linear(
+            return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
                 block_size=self.quant_config.weight_block_size,
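For orientation (editor's sketch, not part of the commit): after this change the block-quant backend is resolved once, when Fp8LinearMethod is constructed, via dispatch_w8a8_block_fp8_linear() (defined in fp8_utils.py below), and apply() simply invokes the stored callable. A minimal sketch of that call pattern, with illustrative names, assuming sglang is installed:

# Editor's sketch (assumed usage mirroring the hunks above; not an excerpt from the commit).
from typing import List, Optional
import torch
from sglang.srt.layers.quantization.fp8_utils import dispatch_w8a8_block_fp8_linear

# Resolved once, e.g. at Fp8LinearMethod.__init__ time.
_w8a8_block_fp8_linear = dispatch_w8a8_block_fp8_linear()

def block_fp8_linear(
    x: torch.Tensor,
    weight: torch.Tensor,            # fp8 block-quantized weight, shape (N, K)
    weight_scale: torch.Tensor,      # per-block scales, shape (ceil(N/bn), ceil(K/bk))
    block_size: List[int],           # e.g. [128, 128]
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Mirrors the block-quant branch of Fp8LinearMethod.apply() after the refactor.
    return _w8a8_block_fp8_linear(
        input=x,
        weight=weight,
        block_size=block_size,
        weight_scale=weight_scale,
        input_scale=None,  # the block-wise path expects no activation scale here
        bias=bias,
    )
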
@@ -740,7 +740,59 @@ if _is_hip:
         return _w8a8_block_fp8_matmul


-def w8a8_block_fp8_matmul(
+def prepare_block_fp8_matmul_inputs(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> Tuple[int, int, int, torch.Tensor]:
+    assert len(block_size) == 2
+    block_n, block_k = block_size[0], block_size[1]
+
+    assert A.shape[-1] == B.shape[-1]
+    assert A.shape[:-1] == As.shape[:-1]
+    assert A.is_contiguous()
+    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+
+    M = A.numel() // A.shape[-1]
+
+    assert B.ndim == 2
+    assert B.is_contiguous()
+    assert Bs.ndim == 2
+    N, K = B.shape
+    assert triton.cdiv(N, block_n) == Bs.shape[0]
+    assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+    C_shape = A.shape[:-1] + (N,)
+    C = A.new_empty(C_shape, dtype=output_dtype)
+
+    return M, N, K, C
+
+
+def w8a8_block_fp8_matmul_deepgemm(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype,
+) -> torch.Tensor:
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)
+
+    # Deepgemm only supports output tensor type as bfloat16
+    assert C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM
+
+    if supports_custom_op():
+        torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
+    else:
+        deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+
+    return C
+
+
+def w8a8_block_fp8_matmul_triton(
     A: torch.Tensor,
     B: torch.Tensor,
     As: torch.Tensor,
@@ -764,81 +816,81 @@ def w8a8_block_fp8_matmul(
     Returns:
         torch.Tensor: The result of matmul.
     """
-    assert len(block_size) == 2
-    block_n, block_k = block_size[0], block_size[1]
-
-    assert A.shape[-1] == B.shape[-1]
-    assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
-    assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
-    M = A.numel() // A.shape[-1]
+    M, N, K, C = prepare_block_fp8_matmul_inputs(A, B, As, Bs, block_size, output_dtype)

-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
-    N, K = B.shape
-    assert triton.cdiv(N, block_n) == Bs.shape[0]
-    assert triton.cdiv(K, block_k) == Bs.shape[1]
+    block_n, block_k = block_size

-    C_shape = A.shape[:-1] + (N,)
-    C = A.new_empty(C_shape, dtype=output_dtype)
-
-    # deepgemm only support bf16
-    if C.dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
-        if supports_custom_op():
-            torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
-        else:
-            deep_gemm_gemm_nt_f8f8bf16((A, As), (B, Bs), C)
+    configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
+    if configs:
+        # If an optimal configuration map has been found, look up the
+        # optimal config
+        config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
     else:
-        configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1])
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Default config
-            # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
-            config = {
-                "BLOCK_SIZE_M": 64,
-                "BLOCK_SIZE_N": block_size[0],
-                "BLOCK_SIZE_K": block_size[1],
-                "GROUP_SIZE_M": 32,
-                "num_warps": 4,
-                "num_stages": 3,
-            }
+        # Default config
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_size[1]
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": block_size[0],
+            "BLOCK_SIZE_K": block_size[1],
+            "GROUP_SIZE_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }

-        def grid(META):
-            return (
-                triton.cdiv(M, META["BLOCK_SIZE_M"])
-                * triton.cdiv(N, META["BLOCK_SIZE_N"]),
-            )
-
-        kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
-
-        kernel[grid](
-            A,
-            B,
-            C,
-            As,
-            Bs,
-            M,
-            N,
-            K,
-            block_n,
-            block_k,
-            A.stride(-2),
-            A.stride(-1),
-            B.stride(1),
-            B.stride(0),
-            C.stride(-2),
-            C.stride(-1),
-            As.stride(-2),
-            As.stride(-1),
-            Bs.stride(1),
-            Bs.stride(0),
-            **config,
-        )
+    def grid(META):
+        return (
+            triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        )
+
+    kernel = select_w8a8_block_fp8_matmul_kernel(M, N, config)
+
+    kernel[grid](
+        A,
+        B,
+        C,
+        As,
+        Bs,
+        M,
+        N,
+        K,
+        block_n,
+        block_k,
+        A.stride(-2),
+        A.stride(-1),
+        B.stride(1),
+        B.stride(0),
+        C.stride(-2),
+        C.stride(-1),
+        As.stride(-2),
+        As.stride(-1),
+        Bs.stride(1),
+        Bs.stride(0),
+        **config,
+    )

     return C
+
+
+# universal entry point, for testing purposes
+def w8a8_block_fp8_matmul(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    As: torch.Tensor,
+    Bs: torch.Tensor,
+    block_size: List[int],
+    output_dtype: torch.dtype = torch.float16,
+) -> torch.Tensor:
+    if output_dtype == torch.bfloat16 and _ENABLE_JIT_DEEPGEMM:
+        return w8a8_block_fp8_matmul_deepgemm(
+            A, B, As, Bs, block_size, output_dtype=output_dtype
+        )
+
+    return w8a8_block_fp8_matmul_triton(
+        A, B, As, Bs, block_size, output_dtype=output_dtype
+    )


 @triton.jit
 def _per_tensor_quant_mla_fp8_stage1(
     x_ptr,
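Editor's note, not part of the diff: the sketch below illustrates the shape contract that prepare_block_fp8_matmul_inputs asserts and how the universal entry point routes between the DeepGEMM and Triton kernels. The sizes, device, and dtype choices are assumptions; it needs an sglang install and fp8-capable CUDA hardware.

# Editor's sketch: illustrative shapes satisfying the asserts in
# prepare_block_fp8_matmul_inputs (assumed environment: sglang + CUDA fp8 support).
import torch
from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

M, N, K = 16, 512, 256           # tokens, output features, input features
block_n, block_k = 128, 128      # block-wise quantization granularity

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)   # quantized activations, (M, K)
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)   # quantized weight, (N, K)
As = torch.rand(M, K // block_k, device="cuda", dtype=torch.float32)             # ceil(K / block_k) scales per token
Bs = torch.rand(N // block_n, K // block_k, device="cuda", dtype=torch.float32)  # ceil(N / block_n) x ceil(K / block_k)

# bfloat16 output routes to the DeepGEMM path when _ENABLE_JIT_DEEPGEMM is set;
# any other output dtype (or no DeepGEMM) goes through the Triton kernel.
C = w8a8_block_fp8_matmul(A, B, As, Bs, [block_n, block_k], output_dtype=torch.bfloat16)
assert C.shape == (M, N)
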
@@ -1,5 +1,6 @@
 import os
-from typing import List, Optional, Tuple
+from curses import flash
+from typing import Callable, List, Optional, Tuple

 import torch

@@ -21,7 +22,8 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
     static_quant_fp8,
-    w8a8_block_fp8_matmul,
+    w8a8_block_fp8_matmul_deepgemm,
+    w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -134,7 +136,20 @@ if ENABLE_FLASHINFER_GEMM:
     from flashinfer.gemm import gemm_fp8_nt_groupwise


-def apply_w8a8_block_fp8_linear(
+def dispatch_w8a8_block_fp8_linear() -> Callable:
+    if ENABLE_FLASHINFER_GEMM:
+        return flashinfer_gemm_w8a8_block_fp8_linear
+    elif CUTLASS_BLOCK_FP8_SUPPORTED:
+        return cutlass_w8a8_block_fp8_linear_with_fallback
+    elif _is_hip and use_aiter_moe:
+        return aiter_w8a8_block_fp8_linear
+    elif _ENABLE_JIT_DEEPGEMM:
+        return deepgemm_w8a8_block_fp8_linear_with_fallback
+    else:
+        return triton_w8a8_block_fp8_linear
+
+
+def flashinfer_gemm_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,
     block_size: List[int],
@@ -143,58 +158,148 @@ def apply_w8a8_block_fp8_linear(
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     assert input_scale is None
-    # View input as 2D matrix for fp8 methods
+
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
-    # TODO: add more robust shape check here
-    shape_supported_by_cutlass = (
-        weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
-    )
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+
+    x_scale_input = x_scale.T.contiguous()
+    weight_scale_input = weight_scale.T.contiguous()
+
+    output = gemm_fp8_nt_groupwise(
+        q_input, weight, x_scale_input, weight_scale_input, out_dtype=input_2d.dtype
+    )
-    if ENABLE_FLASHINFER_GEMM:
-        q_input, x_scale = sglang_per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        x_scale_input = x_scale.T.contiguous()
-        weight_scale_input = weight_scale.T.contiguous()
-        output = gemm_fp8_nt_groupwise(
-            q_input, weight, x_scale_input, weight_scale_input, out_dtype=input.dtype
-        )
-    elif CUTLASS_BLOCK_FP8_SUPPORTED and shape_supported_by_cutlass:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=True
-        )
-        output = fp8_blockwise_scaled_mm(
-            q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
-        )
-    elif _is_hip and use_aiter_moe:
-        q_input, x_scale = per_token_group_quant_fp8(
-            input_2d, block_size[1], column_major_scales=False
-        )
-        output = torch.zeros(
-            [q_input.shape[0], weight.shape[0]],
-            dtype=input.dtype,
-            device=q_input.device,
-        )
-        gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
-    else:
-        if _ENABLE_JIT_DEEPGEMM:
-            q_input, x_scale = sglang_per_token_group_quant_fp8(
-                input_2d,
-                block_size[1],
-                column_major_scales=True,
-                scale_tma_aligned=True,
-            )
-        else:
-            q_input, x_scale = per_token_group_quant_fp8(
-                input_2d, block_size[1], column_major_scales=False
-            )
-        output = w8a8_block_fp8_matmul(
-            q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
-        )

     if bias is not None:
-        output = output + bias
-    return output.to(dtype=input.dtype).view(*output_shape)
+        output += bias
+
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def cutlass_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    if not shape_supported:
+        # fallback to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=True
+    )
+    output = fp8_blockwise_scaled_mm(
+        q_input, weight.T, x_scale, weight_scale.T, out_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def deepgemm_w8a8_block_fp8_linear_with_fallback(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+
+    output_dtype = input.dtype
+    dtype_supported = output_dtype == torch.bfloat16
+
+    # TODO: add more robust shape check here
+    shape_supported = weight.shape[0] % 128 == 0 and weight.shape[1] % 128 == 0
+
+    if not (shape_supported and dtype_supported):
+        # fall back to triton
+        return triton_w8a8_block_fp8_linear(
+            input, weight, block_size, weight_scale, input_scale, bias
+        )
+
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = sglang_per_token_group_quant_fp8(
+        input_2d,
+        block_size[1],
+        column_major_scales=True,
+        scale_tma_aligned=True,
+    )
+    output = w8a8_block_fp8_matmul_deepgemm(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=output_dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=output_dtype).view(*output_shape)
+
+
+def aiter_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = torch.zeros(
+        [q_input.shape[0], weight.shape[0]],
+        dtype=input_2d.dtype,
+        device=q_input.device,
+    )
+    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
+
+    if bias is not None:
+        output += bias
+
+    return output.to(dtype=input_2d.dtype).view(*output_shape)
+
+
+def triton_w8a8_block_fp8_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    block_size: List[int],
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+    bias: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    assert input_scale is None
+    input_2d = input.view(-1, input.shape[-1])
+    output_shape = [*input.shape[:-1], weight.shape[0]]
+
+    q_input, x_scale = per_token_group_quant_fp8(
+        input_2d, block_size[1], column_major_scales=False
+    )
+    output = w8a8_block_fp8_matmul_triton(
+        q_input, weight, x_scale, weight_scale, block_size, output_dtype=input_2d.dtype
+    )
+    if bias is not None:
+        output += bias
+    return output.to(dtype=input_2d.dtype).view(*output_shape)


 def input_to_float8(
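Editor's note, not part of the diff: a hedged sketch of calling the Triton backend wrapper directly, the path every fallback above ultimately targets. Names, sizes, and environment are illustrative assumptions.

# Editor's sketch: direct call into the Triton backend wrapper
# (assumed environment: sglang + CUDA fp8 support; values illustrative).
import torch
from sglang.srt.layers.quantization.fp8_utils import triton_w8a8_block_fp8_linear

M, N, K = 8, 512, 256
block_size = [128, 128]

x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)                 # high-precision activations
weight = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)          # block-quantized fp8 weight
weight_scale = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)  # per-block scales

# The wrapper quantizes the activation per token group internally;
# input_scale must be None on the block-wise path.
out = triton_w8a8_block_fp8_linear(x, weight, block_size, weight_scale, input_scale=None, bias=None)
assert out.shape == (M, N)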