Support FP4 gemm (1/2) (#3899)
This commit is contained in:
@@ -26,9 +26,11 @@ from sgl_kernel.gemm import (
|
||||
awq_dequantize,
|
||||
bmm_fp8,
|
||||
cublas_grouped_gemm,
|
||||
cutlass_scaled_fp4_mm,
|
||||
fp8_blockwise_scaled_mm,
|
||||
fp8_scaled_mm,
|
||||
int8_scaled_mm,
|
||||
scaled_fp4_quant,
|
||||
sgl_per_tensor_quant_fp8,
|
||||
sgl_per_token_group_quant_fp8,
|
||||
sgl_per_token_group_quant_int8,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
||||
@@ -145,3 +145,73 @@ def sgl_per_token_quant_fp8(
|
||||
output_s: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
|
||||
|
||||
|
||||
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Run the CUTLASS FP4 GEMM kernel on two packed-FP4 matrices.

    Args:
        a: Left operand, 2-D.
        b: Right operand, 2-D; the output has shape (a.shape[0], b.shape[0]).
        block_scale_a: Per-block scaling factors for `a`.
        block_scale_b: Per-block scaling factors for `b`.
        alpha: Scalar multiplier tensor applied by the kernel.
        out_dtype: Element type of the returned matrix.

    Returns:
        torch.Tensor: Freshly allocated (m, n) result on `a`'s device,
        filled in-place by the kernel.
    """
    assert a.ndim == 2 and b.ndim == 2
    rows = a.shape[0]
    cols = b.shape[0]
    # The kernel writes into a preallocated output buffer.
    result = torch.empty((rows, cols), dtype=out_dtype, device=a.device)
    # NOTE(review): this op is dispatched under "sgl_kernels" while
    # sgl_per_token_quant_fp8 above uses "sgl_kernel" — confirm which
    # namespace the extension actually registers.
    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
        result, a, b, block_scale_a, block_scale_b, alpha
    )
    return result
|
||||
|
||||
|
||||
def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    Quantization acts along the last dimension of `input`: each group of 16
    consecutive elements shares one dynamically computed scaling factor. That
    factor is itself quantized using `input_global_scale` and stored in a
    swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4 (fp16 or bf16).
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The FP4 output, with every two
        values packed into one uint8, and the float8_e4m3 scaling factors
        in a swizzled layout.
    """
    assert input.ndim >= 1, f"input.ndim needs to be >= 1, but got {input.ndim}."
    # Collapse leading dims into rows; a 1-D input becomes a single row.
    lead = 1 if input.ndim == 1 else -1
    input = input.reshape(lead, input.shape[-1])
    m, n = input.shape
    block_size = 16
    device = input.device

    assert n % block_size == 0, f"last dim has to be multiple of 16, but got {n}."
    valid_dtypes = (torch.float16, torch.bfloat16)
    assert (
        input.dtype in valid_dtypes
    ), f"input.dtype needs to be fp16 or bf16 but got {input.dtype}."

    def _round_up(value: int, multiple: int) -> int:
        # Smallest multiple of `multiple` that is >= value.
        return ((value + multiple - 1) // multiple) * multiple

    # Two fp4 values are packed into a single uint8.
    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)

    # The swizzled scale buffer is padded: rows rounded up to 128, and the
    # per-block float8_e4m3fn scales packed four-at-a-time into int32.
    rounded_m = _round_up(m, 128)
    scale_n = n // block_size
    rounded_n = _round_up(scale_n, 4)
    output_scale = torch.empty(
        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )

    torch.ops.sgl_kernels.scaled_fp4_quant(
        output, input, output_scale, input_global_scale
    )
    # Reinterpret the packed int32 storage as its float8_e4m3fn contents.
    output_scale = output_scale.view(torch.float8_e4m3fn)
    return output, output_scale
|
||||
|
||||
Reference in New Issue
Block a user