2025-03-03 06:36:40 -08:00
|
|
|
from typing import List, Optional
|
|
|
|
|
|
|
|
|
|
import torch
|
2025-03-08 22:54:51 -08:00
|
|
|
from sgl_kernel.utils import _get_cache_buf, get_cuda_stream
|
2025-03-03 06:36:40 -08:00
|
|
|
|
|
|
|
|
|
2025-03-12 00:10:02 -07:00
|
|
|
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ-packed weights via the ``sgl_kernel`` custom op.

    Args:
        qweight: Packed quantized weight tensor.
        scales: Scale factors applied during dequantization.
        qzeros: Packed zero points.

    Returns:
        The dequantized weight tensor produced by
        ``torch.ops.sgl_kernel.awq_dequantize``.

    NOTE(review): the original annotation was ``torch.ByteTensor``, but a
    dequantize op is expected to yield a floating-point tensor, so the
    annotation was relaxed to ``torch.Tensor`` — confirm against the
    kernel's registered schema.
    """
    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
|
2025-03-12 00:10:02 -07:00
|
|
|
|
|
|
|
|
|
2025-03-03 06:36:40 -08:00
|
|
|
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Scaled int8 matrix multiply, dispatched to the sgl_kernel custom op.

    Thin wrapper: all arguments are forwarded unchanged to
    ``torch.ops.sgl_kernel.int8_scaled_mm``; ``bias`` defaults to ``None``.
    """
    kernel = torch.ops.sgl_kernel.int8_scaled_mm
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """Blockwise-scaled fp8 matrix multiply via the sgl_kernel custom op.

    Thin wrapper: forwards every argument unchanged to
    ``torch.ops.sgl_kernel.fp8_blockwise_scaled_mm``.
    """
    kernel = torch.ops.sgl_kernel.fp8_blockwise_scaled_mm
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Scaled fp8 matrix multiply, dispatched to the sgl_kernel custom op.

    Thin wrapper: all arguments are forwarded unchanged to
    ``torch.ops.sgl_kernel.fp8_scaled_mm``; ``bias`` defaults to ``None``.
    """
    kernel = torch.ops.sgl_kernel.fp8_scaled_mm
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke the fp8 batched-matmul kernel, writing the result into ``D``.

    The custom op additionally needs the current cuBLAS handle and CUDA
    stream, which are looked up here rather than passed by the caller.
    """
    handle = torch.cuda.current_blas_handle()
    stream = get_cuda_stream()
    torch.ops.sgl_kernel.bmm_fp8(
        A, B, D, A_scale, B_scale, workspace_buffer, handle, stream
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched fp8 matmul of ``A`` and ``B`` with per-tensor scales.

    Args:
        A: Left operand, indexed as ``A[batch, m, k]``.
        B: Right operand, indexed as ``B[batch, k, n]``.
        A_scale: Scale for ``A``.
        B_scale: Scale for ``B``.
        dtype: Output dtype used when ``out`` is allocated here.
        out: Optional preallocated output; a fresh tensor of shape
            ``(A.shape[0], A.shape[1], B.shape[2])`` is created when ``None``.

    Returns:
        The output tensor (``out`` if supplied, otherwise the new tensor).
    """
    if out is None:
        batch, rows = A.shape[0], A.shape[1]
        cols = B.shape[2]
        out = torch.empty((batch, rows, cols), device=A.device, dtype=dtype)
    # 32 MiB scratch buffer, cached per device under a fixed key.
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, out, A_scale, B_scale)
    return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sgl_per_token_group_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
) -> None:
    """Per-token-group fp8 quantization via the sgl_kernel custom op.

    Writes quantized values into ``output_q`` and scales into ``output_s``
    in place; returns nothing. All arguments are forwarded unchanged.
    """
    quantize = torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8
    quantize(input, output_q, output_s, group_size, eps, fp8_min, fp8_max)
|
|
|
|
|
|
|
|
|
|
|
2025-03-07 10:05:43 +08:00
|
|
|
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Per-tensor fp8 quantization via the sgl_kernel custom op.

    Fills ``output_q`` and ``output_s`` in place; ``is_static`` is passed
    through to the kernel unchanged. Returns nothing.
    """
    quantize = torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8
    quantize(input, output_q, output_s, is_static)
|
2025-03-07 10:05:43 +08:00
|
|
|
|
|
|
|
|
|
2025-03-03 06:36:40 -08:00
|
|
|
def cublas_grouped_gemm(
    inputs: List[torch.Tensor],
    weights: List[torch.Tensor],
    outputs: List[torch.Tensor],
    out_dtype: torch.dtype,
) -> None:
    """Run a grouped GEMM over lists of matrices via the cuBLAS custom op.

    Args:
        inputs: Input matrices, one per group.
        weights: Weight matrices, one per group.
        outputs: Preallocated output matrices, written in place.
        out_dtype: Output dtype forwarded to the kernel.

    Raises:
        ValueError: if any of ``inputs``/``weights``/``outputs`` is empty.
    """
    # Validate with an explicit exception rather than `assert`: asserts are
    # silently stripped when Python runs with -O, which would let empty
    # lists reach the kernel.
    if not (inputs and weights and outputs):
        raise ValueError("Inputs/weights/outputs should not be empty!")
    # The kernel needs the current cuBLAS handle and CUDA stream.
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.cublas_grouped_gemm(
        inputs,
        weights,
        outputs,
        out_dtype,
        cublas_handle,
        get_cuda_stream(),
    )
|
2025-03-06 20:53:05 -08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Per-token fp8 quantization via the sgl_kernel custom op.

    Fills ``output_q`` and ``output_s`` in place; returns nothing.
    """
    quantize = torch.ops.sgl_kernel.sgl_per_token_quant_fp8
    quantize(input, output_q, output_s)
|