fix: remove cublas_grouped_gemm (#5307)
This commit is contained in:
@@ -164,7 +164,6 @@ set(SOURCES
|
|||||||
"csrc/elementwise/rope.cu"
|
"csrc/elementwise/rope.cu"
|
||||||
"csrc/gemm/awq_kernel.cu"
|
"csrc/gemm/awq_kernel.cu"
|
||||||
"csrc/gemm/bmm_fp8.cu"
|
"csrc/gemm/bmm_fp8.cu"
|
||||||
"csrc/gemm/cublas_grouped_gemm.cu"
|
|
||||||
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
|
"csrc/gemm/fp8_blockwise_gemm_kernel.cu"
|
||||||
"csrc/gemm/fp8_gemm_kernel.cu"
|
"csrc/gemm/fp8_gemm_kernel.cu"
|
||||||
"csrc/gemm/int8_gemm_kernel.cu"
|
"csrc/gemm/int8_gemm_kernel.cu"
|
||||||
|
|||||||
@@ -1,262 +0,0 @@
|
|||||||
import argparse
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
from sgl_kernel import cublas_grouped_gemm
|
|
||||||
|
|
||||||
WEIGHT_CONFIGS = {
|
|
||||||
"DeepSeek-V2-Lite": {
|
|
||||||
"num_routed_experts": 64,
|
|
||||||
"ffn_shapes": [
|
|
||||||
[2048, 2816],
|
|
||||||
[1408, 2048],
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"DeepSeek-V2": {
|
|
||||||
"num_routed_experts": 160,
|
|
||||||
"ffn_shapes": [
|
|
||||||
[5120, 3072],
|
|
||||||
[1536, 5120],
|
|
||||||
],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# This Triton Grouped Gemm Kernel is adapted from
|
|
||||||
# https://triton-lang.org/main/getting-started/tutorials/08-grouped-gemm.html
|
|
||||||
@triton.jit
|
|
||||||
def grouped_matmul_kernel(
|
|
||||||
# device tensor of matrices pointers
|
|
||||||
group_a_ptrs,
|
|
||||||
group_b_ptrs,
|
|
||||||
group_c_ptrs,
|
|
||||||
# device tensor of gemm sizes. its shape is [group_size, 3]
|
|
||||||
# dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
|
|
||||||
group_gemm_sizes,
|
|
||||||
# device tensor of leading dimension sizes. its shape is [group_size, 3]
|
|
||||||
# dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
|
|
||||||
g_lds,
|
|
||||||
# Factors for multiplication.
|
|
||||||
alphas,
|
|
||||||
betas,
|
|
||||||
# number of gemms
|
|
||||||
group_size,
|
|
||||||
# number of virtual SM
|
|
||||||
NUM_SM: tl.constexpr,
|
|
||||||
# tile sizes
|
|
||||||
BLOCK_SIZE_M: tl.constexpr,
|
|
||||||
BLOCK_SIZE_N: tl.constexpr,
|
|
||||||
BLOCK_SIZE_K: tl.constexpr,
|
|
||||||
):
|
|
||||||
tile_idx = tl.program_id(0)
|
|
||||||
last_problem_end = 0
|
|
||||||
for g in range(group_size):
|
|
||||||
# get the gemm size of the current problem
|
|
||||||
gm = tl.load(group_gemm_sizes + g * 3)
|
|
||||||
gn = tl.load(group_gemm_sizes + g * 3 + 1)
|
|
||||||
gk = tl.load(group_gemm_sizes + g * 3 + 2)
|
|
||||||
num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M)
|
|
||||||
num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N)
|
|
||||||
num_tiles = num_m_tiles * num_n_tiles
|
|
||||||
# load multiplication factors
|
|
||||||
alpha = tl.load(alphas + g)
|
|
||||||
beta = tl.load(betas + g)
|
|
||||||
# iterate through the tiles in the current gemm problem
|
|
||||||
while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles:
|
|
||||||
# pick up a tile from the current gemm problem
|
|
||||||
k = gk
|
|
||||||
lda = tl.load(g_lds + g * 3)
|
|
||||||
ldb = tl.load(g_lds + g * 3 + 1)
|
|
||||||
ldc = tl.load(g_lds + g * 3 + 2)
|
|
||||||
a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16))
|
|
||||||
b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16))
|
|
||||||
c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16))
|
|
||||||
# figure out tile coordinates
|
|
||||||
tile_idx_in_gemm = tile_idx - last_problem_end
|
|
||||||
tile_m_idx = tile_idx_in_gemm // num_n_tiles
|
|
||||||
tile_n_idx = tile_idx_in_gemm % num_n_tiles
|
|
||||||
|
|
||||||
# do regular gemm here
|
|
||||||
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
|
||||||
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
offs_k = tl.arange(0, BLOCK_SIZE_K)
|
|
||||||
a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :]
|
|
||||||
b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :]
|
|
||||||
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
|
|
||||||
for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)):
|
|
||||||
a = tl.load(
|
|
||||||
a_ptrs,
|
|
||||||
mask=(offs_am[:, None] < gm)
|
|
||||||
and (offs_k[None, :] < gk - kk * BLOCK_SIZE_K),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
b = tl.load(
|
|
||||||
b_ptrs,
|
|
||||||
mask=(offs_k[:, None] < gk - kk * BLOCK_SIZE_K)
|
|
||||||
and (offs_bn[None, :] < gn),
|
|
||||||
other=0.0,
|
|
||||||
)
|
|
||||||
accumulator += tl.dot(a, b)
|
|
||||||
a_ptrs += BLOCK_SIZE_K
|
|
||||||
b_ptrs += BLOCK_SIZE_K * ldb
|
|
||||||
accumulator *= alpha
|
|
||||||
c = accumulator.to(tl.float16)
|
|
||||||
|
|
||||||
offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
|
|
||||||
offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
|
|
||||||
c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :]
|
|
||||||
output_mask = (offs_am[:, None] < gm) and (offs_bn[None, :] < gn)
|
|
||||||
c += beta * tl.load(c_ptrs, mask=output_mask)
|
|
||||||
tl.store(c_ptrs, c, mask=output_mask)
|
|
||||||
|
|
||||||
# go to the next tile by advancing NUM_SM
|
|
||||||
tile_idx += NUM_SM
|
|
||||||
|
|
||||||
# get ready to go to the next gemm problem
|
|
||||||
last_problem_end = last_problem_end + num_tiles
|
|
||||||
|
|
||||||
|
|
||||||
def triton_perf_fn(group_A, group_B, group_C, dtype):
|
|
||||||
# We put the process of matrix lengths and pointers here out of fairness,
|
|
||||||
# since cublas_grouped_gemm kernel also does these work.
|
|
||||||
group_size = len(group_A)
|
|
||||||
A_addrs = []
|
|
||||||
B_addrs = []
|
|
||||||
C_addrs = []
|
|
||||||
g_sizes = []
|
|
||||||
g_lds = []
|
|
||||||
alphas = [1.0] * group_size
|
|
||||||
betas = [0.0] * group_size
|
|
||||||
for i in range(group_size):
|
|
||||||
M, N, K = group_A[i].shape[0], group_B[i].shape[1], group_A[i].shape[1]
|
|
||||||
g_sizes += [M, N, K]
|
|
||||||
g_lds += [K, N, N]
|
|
||||||
A_addrs.append(group_A[i].data_ptr())
|
|
||||||
B_addrs.append(group_B[i].data_ptr())
|
|
||||||
C_addrs.append(group_C[i].data_ptr())
|
|
||||||
|
|
||||||
d_a_ptrs = torch.tensor(A_addrs, device="cuda")
|
|
||||||
d_b_ptrs = torch.tensor(B_addrs, device="cuda")
|
|
||||||
d_c_ptrs = torch.tensor(C_addrs, device="cuda")
|
|
||||||
d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda")
|
|
||||||
d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda")
|
|
||||||
d_alphas = torch.tensor(alphas, dtype=torch.float32, device="cuda")
|
|
||||||
d_betas = torch.tensor(betas, dtype=torch.float32, device="cuda")
|
|
||||||
|
|
||||||
NUM_SM = 128
|
|
||||||
grid = (NUM_SM,)
|
|
||||||
grouped_matmul_kernel[grid](
|
|
||||||
d_a_ptrs,
|
|
||||||
d_b_ptrs,
|
|
||||||
d_c_ptrs,
|
|
||||||
d_g_sizes,
|
|
||||||
d_g_lds,
|
|
||||||
d_alphas,
|
|
||||||
d_betas,
|
|
||||||
group_size,
|
|
||||||
NUM_SM=NUM_SM,
|
|
||||||
BLOCK_SIZE_M=128,
|
|
||||||
BLOCK_SIZE_N=128,
|
|
||||||
BLOCK_SIZE_K=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def cublas_perf_fn(group_A, group_B, group_C, dtype):
|
|
||||||
cublas_grouped_gemm(group_A, group_B, group_C, dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.testing.perf_report(
|
|
||||||
triton.testing.Benchmark(
|
|
||||||
x_names=["M"],
|
|
||||||
x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
|
|
||||||
x_log=False,
|
|
||||||
line_arg="provider",
|
|
||||||
line_vals=[
|
|
||||||
"triton",
|
|
||||||
"cublas",
|
|
||||||
],
|
|
||||||
line_names=[
|
|
||||||
"triton",
|
|
||||||
"cublas",
|
|
||||||
],
|
|
||||||
styles=[("green", "-"), ("blue", "-")],
|
|
||||||
ylabel="gbps",
|
|
||||||
plot_name="grouped gemm",
|
|
||||||
args={},
|
|
||||||
)
|
|
||||||
)
|
|
||||||
def benchmark(M, provider, N, K):
|
|
||||||
group_size = 20 # Number of used experts per gpu is usually around 20
|
|
||||||
group_A = []
|
|
||||||
group_B_row_major = []
|
|
||||||
group_B_col_major = []
|
|
||||||
group_C = []
|
|
||||||
dtype = torch.float16
|
|
||||||
for i in range(group_size):
|
|
||||||
A = torch.rand((M, K), device="cuda", dtype=dtype)
|
|
||||||
B_row_major = torch.rand((K, N), device="cuda", dtype=dtype)
|
|
||||||
B_col_major = torch.rand((N, K), device="cuda", dtype=dtype)
|
|
||||||
C = torch.empty((M, N), device="cuda", dtype=dtype)
|
|
||||||
group_A.append(A)
|
|
||||||
group_B_row_major.append(B_row_major)
|
|
||||||
group_B_col_major.append(B_col_major)
|
|
||||||
group_C.append(C)
|
|
||||||
|
|
||||||
quantiles = [0.5, 0.2, 0.8]
|
|
||||||
if "triton" in provider:
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: triton_perf_fn(group_A, group_B_row_major, group_C, dtype),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
elif "cublas" in provider:
|
|
||||||
ms, min_ms, max_ms = triton.testing.do_bench(
|
|
||||||
lambda: cublas_perf_fn(group_A, group_B_col_major, group_C, dtype),
|
|
||||||
quantiles=quantiles,
|
|
||||||
)
|
|
||||||
|
|
||||||
gbps = (
|
|
||||||
lambda ms: group_size
|
|
||||||
* (2 * M * N * K + 2 * M * N)
|
|
||||||
* group_A[0].element_size()
|
|
||||||
* 1e-9
|
|
||||||
/ (ms * 1e-3)
|
|
||||||
)
|
|
||||||
return gbps(ms), gbps(max_ms), gbps(min_ms)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument(
|
|
||||||
"--models",
|
|
||||||
nargs="+",
|
|
||||||
type=str,
|
|
||||||
default=["DeepSeek-V2"],
|
|
||||||
help="List of models to benchmark",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--tp-size",
|
|
||||||
type=int,
|
|
||||||
default=8,
|
|
||||||
help="Tensor parallel size",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
for model in args.models:
|
|
||||||
assert model in WEIGHT_CONFIGS
|
|
||||||
num_experts_per_device = (
|
|
||||||
WEIGHT_CONFIGS[model]["num_routed_experts"] // args.tp_size
|
|
||||||
)
|
|
||||||
for K, N in WEIGHT_CONFIGS[model]["ffn_shapes"]:
|
|
||||||
print(
|
|
||||||
f"{model} N={N} K={K} tp_size={args.tp_size} "
|
|
||||||
f"group_size=num_experts_per_device={num_experts_per_device}: "
|
|
||||||
)
|
|
||||||
benchmark.run(
|
|
||||||
print_data=True,
|
|
||||||
show_plots=True,
|
|
||||||
save_path="bench_grouped_gemm_res",
|
|
||||||
N=N,
|
|
||||||
K=K,
|
|
||||||
)
|
|
||||||
|
|
||||||
print("Benchmark finished!")
|
|
||||||
@@ -112,11 +112,6 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
|||||||
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
|
||||||
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
|
||||||
|
|
||||||
m.def(
|
|
||||||
"cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
|
|
||||||
" ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
|
|
||||||
m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
|
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
"cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
|
||||||
" Tensor block_scale_a, Tensor block_scale_b,"
|
" Tensor block_scale_a, Tensor block_scale_b,"
|
||||||
|
|||||||
@@ -1,172 +0,0 @@
|
|||||||
// References:
|
|
||||||
// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex
|
|
||||||
// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu
|
|
||||||
// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp
|
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
|
||||||
#include <ATen/cuda/CUDAContext.h>
|
|
||||||
#include <c10/cuda/CUDAGuard.h>
|
|
||||||
#include <c10/util/Exception.h>
|
|
||||||
#include <cublas_v2.h>
|
|
||||||
#include <cudaTypedefs.h>
|
|
||||||
#include <cuda_fp16.h>
|
|
||||||
#include <cuda_runtime.h>
|
|
||||||
#include <torch/all.h>
|
|
||||||
|
|
||||||
#include <cstdio>
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "utils.h"
|
|
||||||
|
|
||||||
static void check_group_count(
|
|
||||||
const std::vector<torch::Tensor>& inputs,
|
|
||||||
const std::vector<torch::Tensor>& weights,
|
|
||||||
const std::vector<torch::Tensor>& outputs) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
((inputs.size() == weights.size()) && (inputs.size() == outputs.size())),
|
|
||||||
"The group count of inputs, weights and outputs should be the same.");
|
|
||||||
}
|
|
||||||
|
|
||||||
static void check_device_dtype(const torch::Dtype& dtype, const std::vector<torch::Tensor>& tensors) {
|
|
||||||
for (const auto& t : tensors) {
|
|
||||||
TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same");
|
|
||||||
TORCH_CHECK(t.is_cuda(), "All tensors should be in Cuda memory");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<int> get_dims(const std::vector<torch::Tensor>& tensors, int dim) {
|
|
||||||
std::vector<int> results;
|
|
||||||
for (const auto& t : tensors) {
|
|
||||||
TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices");
|
|
||||||
results.push_back(t.size(dim));
|
|
||||||
}
|
|
||||||
return std::move(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<int> get_strides(const std::vector<torch::Tensor>& tensors, int dim) {
|
|
||||||
std::vector<int> results;
|
|
||||||
for (const auto& t : tensors) {
|
|
||||||
results.push_back(t.stride(dim));
|
|
||||||
}
|
|
||||||
return std::move(results);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void check_equal(const std::vector<int>& a, const std::vector<int>& b, const std::string& err_msg) {
|
|
||||||
for (int i = 0; i < a.size(); ++i) {
|
|
||||||
TORCH_CHECK(a[i] == b[i], err_msg);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
static std::vector<void*> get_tensor_ptrs(const std::vector<torch::Tensor>& tensors) {
|
|
||||||
std::vector<void*> ptrs;
|
|
||||||
for (auto& t : tensors) {
|
|
||||||
ptrs.push_back(t.data_ptr());
|
|
||||||
}
|
|
||||||
return std::move(ptrs);
|
|
||||||
}
|
|
||||||
|
|
||||||
static torch::Tensor create_ptr_pointer(const std::vector<void*>& ptrs, cudaStream_t stream) {
|
|
||||||
auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA);
|
|
||||||
torch::Tensor gpu_ptrs = torch::empty({static_cast<int>(ptrs.size())}, options);
|
|
||||||
TORCH_CHECK(
|
|
||||||
cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), cudaMemcpyHostToDevice, stream) ==
|
|
||||||
CUBLAS_STATUS_SUCCESS);
|
|
||||||
return gpu_ptrs;
|
|
||||||
}
|
|
||||||
|
|
||||||
// We want compute input @ weight^T in row major
|
|
||||||
// This is equivalent to computing weight @ input^T in col major
|
|
||||||
// Cublas only accepts matrix in column major, so this arrangement is needed
|
|
||||||
void cublas_grouped_gemm(
|
|
||||||
const std::vector<torch::Tensor>& inputs, // b: (m, k) row major = (k, m) col major
|
|
||||||
const std::vector<torch::Tensor>& weights, // a: (n, k) row major = (n, k)^T col major
|
|
||||||
const std::vector<torch::Tensor>& outputs, // c: (m, n) row major = (n, m) col major
|
|
||||||
const torch::Dtype& out_dtype,
|
|
||||||
int64_t cublas_handle,
|
|
||||||
int64_t cuda_stream) {
|
|
||||||
TORCH_CHECK(
|
|
||||||
out_dtype == torch::kHalf || out_dtype == torch::kBFloat16,
|
|
||||||
"cublas grouped_gemm can"
|
|
||||||
"only be applied to float16 and bfloat16 dtype");
|
|
||||||
|
|
||||||
int group_count = inputs.size();
|
|
||||||
check_group_count(inputs, weights, outputs);
|
|
||||||
std::vector<int> group_size(group_count, 1);
|
|
||||||
|
|
||||||
// Make sure all tensors are on cuda and use the same dtype
|
|
||||||
check_device_dtype(out_dtype, inputs);
|
|
||||||
check_device_dtype(out_dtype, weights);
|
|
||||||
check_device_dtype(out_dtype, outputs);
|
|
||||||
|
|
||||||
// Weights should be transposed to (n, k) of column major
|
|
||||||
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
|
||||||
std::vector<cublasOperation_t> transb_array(group_count, CUBLAS_OP_N);
|
|
||||||
|
|
||||||
// Get dim arrays
|
|
||||||
std::vector<int> m_array = get_dims(weights, 0);
|
|
||||||
std::vector<int> n_array = get_dims(inputs, 0);
|
|
||||||
std::vector<int> k_array = get_dims(inputs, 1);
|
|
||||||
|
|
||||||
// Make sure the dimensions in each group match
|
|
||||||
std::vector<int> m_array1 = get_dims(outputs, 1);
|
|
||||||
std::vector<int> n_array1 = get_dims(outputs, 0);
|
|
||||||
std::vector<int> k_array1 = get_dims(weights, 1);
|
|
||||||
check_equal(m_array, m_array1, "sizes don't match on m dimension");
|
|
||||||
check_equal(n_array, n_array1, "sizes don't match on n dimension");
|
|
||||||
check_equal(k_array, k_array1, "sizes don't match on k dimension");
|
|
||||||
|
|
||||||
// Get leading dimensions
|
|
||||||
std::vector<int> lda_array = get_strides(weights, 0);
|
|
||||||
std::vector<int> ldb_array = get_strides(inputs, 0);
|
|
||||||
std::vector<int> ldc_array = get_strides(outputs, 0);
|
|
||||||
|
|
||||||
// Use default scaling factors
|
|
||||||
std::vector<float> alpha_array(group_count, 1);
|
|
||||||
std::vector<float> beta_array(group_count, 0);
|
|
||||||
|
|
||||||
std::vector<void*> a_array = get_tensor_ptrs(weights);
|
|
||||||
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
|
||||||
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
|
||||||
|
|
||||||
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
|
||||||
|
|
||||||
// Should allocate tensors for storage of pointers
|
|
||||||
torch::Tensor d_a = create_ptr_pointer(a_array, stream);
|
|
||||||
torch::Tensor d_b = create_ptr_pointer(b_array, stream);
|
|
||||||
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
|
||||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
|
||||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
|
||||||
|
|
||||||
auto status = cublasGemmGroupedBatchedEx(
|
|
||||||
handle,
|
|
||||||
transa_array.data(),
|
|
||||||
transb_array.data(),
|
|
||||||
m_array.data(),
|
|
||||||
n_array.data(),
|
|
||||||
k_array.data(),
|
|
||||||
alpha_array.data(),
|
|
||||||
(void**)d_a.data_ptr(),
|
|
||||||
cuda_data_type,
|
|
||||||
lda_array.data(),
|
|
||||||
(void**)d_b.data_ptr(),
|
|
||||||
cuda_data_type,
|
|
||||||
ldb_array.data(),
|
|
||||||
beta_array.data(),
|
|
||||||
(void**)d_c.data_ptr(),
|
|
||||||
cuda_data_type,
|
|
||||||
ldc_array.data(),
|
|
||||||
group_count,
|
|
||||||
group_size.data(),
|
|
||||||
CUBLAS_COMPUTE_32F);
|
|
||||||
TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status));
|
|
||||||
TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization");
|
|
||||||
return;
|
|
||||||
#endif
|
|
||||||
|
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
|
||||||
false, "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion());
|
|
||||||
}
|
|
||||||
@@ -160,13 +160,6 @@ void sgl_per_token_group_quant_int8(
|
|||||||
double int8_max);
|
double int8_max);
|
||||||
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
void sgl_per_tensor_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, bool is_static);
|
||||||
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
void sgl_per_token_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s);
|
||||||
void cublas_grouped_gemm(
|
|
||||||
const std::vector<torch::Tensor>& inputs,
|
|
||||||
const std::vector<torch::Tensor>& weights,
|
|
||||||
const std::vector<torch::Tensor>& outputs,
|
|
||||||
const torch::Dtype& out_dtype,
|
|
||||||
int64_t cublas_handle,
|
|
||||||
int64_t cuda_stream);
|
|
||||||
void bmm_fp8(
|
void bmm_fp8(
|
||||||
at::Tensor A,
|
at::Tensor A,
|
||||||
at::Tensor B,
|
at::Tensor B,
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ from sgl_kernel.elementwise import (
|
|||||||
from sgl_kernel.gemm import (
|
from sgl_kernel.gemm import (
|
||||||
awq_dequantize,
|
awq_dequantize,
|
||||||
bmm_fp8,
|
bmm_fp8,
|
||||||
cublas_grouped_gemm,
|
|
||||||
cutlass_scaled_fp4_mm,
|
cutlass_scaled_fp4_mm,
|
||||||
fp8_blockwise_scaled_mm,
|
fp8_blockwise_scaled_mm,
|
||||||
fp8_scaled_mm,
|
fp8_scaled_mm,
|
||||||
|
|||||||
@@ -121,26 +121,6 @@ def sgl_per_tensor_quant_fp8(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def cublas_grouped_gemm(
|
|
||||||
inputs: List[torch.Tensor],
|
|
||||||
weights: List[torch.Tensor],
|
|
||||||
outputs: List[torch.Tensor],
|
|
||||||
out_dtype: torch.dtype,
|
|
||||||
) -> None:
|
|
||||||
assert (
|
|
||||||
len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
|
|
||||||
), "Inputs/weights/outputs should not be empty!"
|
|
||||||
cublas_handle = torch.cuda.current_blas_handle()
|
|
||||||
torch.ops.sgl_kernel.cublas_grouped_gemm.default(
|
|
||||||
inputs,
|
|
||||||
weights,
|
|
||||||
outputs,
|
|
||||||
out_dtype,
|
|
||||||
cublas_handle,
|
|
||||||
get_cuda_stream(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def sgl_per_token_quant_fp8(
|
def sgl_per_token_quant_fp8(
|
||||||
input: torch.Tensor,
|
input: torch.Tensor,
|
||||||
output_q: torch.Tensor,
|
output_q: torch.Tensor,
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
import pytest
|
|
||||||
import torch
|
|
||||||
from sgl_kernel import cublas_grouped_gemm
|
|
||||||
|
|
||||||
|
|
||||||
def torch_grouped_gemm(a_array, b_array, out_dtype):
|
|
||||||
return [torch.matmul(a, b.t()).to(out_dtype) for a, b in zip(a_array, b_array)]
|
|
||||||
|
|
||||||
|
|
||||||
skip_condition = not torch.cuda.is_available() or (
|
|
||||||
torch.version.cuda is None
|
|
||||||
or tuple(map(int, torch.version.cuda.split("."))) < (12, 5)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
skip_condition, reason="CUDA not available or CUDA version lower than 12.5"
|
|
||||||
)
|
|
||||||
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.bfloat16])
|
|
||||||
@pytest.mark.parametrize("M", [1, 16, 32, 256, 1024])
|
|
||||||
@pytest.mark.parametrize("N", [2, 16, 128, 256, 4096])
|
|
||||||
@pytest.mark.parametrize("K", [3, 16, 32, 512, 8192])
|
|
||||||
def test_grouped_gemm_accuracy(out_dtype, M, N, K):
|
|
||||||
a = torch.randn((M, K), device="cuda", dtype=out_dtype) * 5
|
|
||||||
b = torch.randn((N, K), device="cuda", dtype=out_dtype) * 5
|
|
||||||
expected = torch.matmul(a, b.t()).to(out_dtype)
|
|
||||||
|
|
||||||
a_array = [a]
|
|
||||||
b_array = [b]
|
|
||||||
c_array = [torch.empty((M, N), device="cuda", dtype=out_dtype)]
|
|
||||||
|
|
||||||
result_torch = torch_grouped_gemm(a_array, b_array, out_dtype)[0]
|
|
||||||
cublas_grouped_gemm(a_array, b_array, c_array, out_dtype)
|
|
||||||
|
|
||||||
torch.testing.assert_close(result_torch, expected)
|
|
||||||
torch.testing.assert_close(c_array[0], expected)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
pytest.main([__file__])
|
|
||||||
Reference in New Issue
Block a user