diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md
index 8c67c2eda..3b4d64930 100644
--- a/docs/references/deepseek.md
+++ b/docs/references/deepseek.md
@@ -77,3 +77,21 @@ Overall, with these optimizations, we have achieved up to a 7x acceleration in o
 - **Weight**: Per-128x128-block quantization for better numerical stability.
 
 **Usage**: turn on by default for DeepSeek V3 models.
+
+### cuBLAS Grouped GEMM
+
+**Description**: The [Grouped GEMM API](https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex) introduced in cuBLAS 12.5 is integrated into SGLang to accelerate
+settings where a group of matrix multiplications with different shapes must be executed. Typical examples are expert parallelism in MoE layers and LoRA modules when serving multiple LoRA adapters.
+
+**Usage**: SGLang currently supports only PyTorch 2.5, which ships with CUDA 12.4 packages. To use this kernel, you need a CUDA environment >= 12.5 and must manually upgrade the cuBLAS package as follows:
+
+1. Make sure the system CUDA version is >= 12.5 with `nvcc -V`.
+2. Install SGLang following the [official documentation](https://docs.sglang.ai/start/install.html).
+3. Reinstall cuBLAS 12.5 with `pip install nvidia-cublas-cu12==12.5.3.2` so that the cuBLAS package is upgraded.
+4. Compile the new sgl-kernel library with `cd sgl-kernel && make build`.
+
+The cuBLAS grouped GEMM kernel can then be imported with
+```python
+from sgl_kernel import cublas_grouped_gemm
+```
+Currently cuBLAS only supports the grouped GEMM kernel for fp16/bf16/fp32 tensors, so fp8 tensors cannot be used.
diff --git a/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py b/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py
new file mode 100644
index 000000000..15f4d7b7d
--- /dev/null
+++ b/sgl-kernel/benchmark/bench_cublas_grouped_gemm.py
@@ -0,0 +1,262 @@
+import argparse
+
+import torch
+import triton
+import triton.language as tl
+from sgl_kernel import cublas_grouped_gemm
+
+WEIGHT_CONFIGS = {
+    "DeepSeek-V2-Lite": {
+        "num_routed_experts": 64,
+        "ffn_shapes": [
+            [2048, 2816],
+            [1408, 2048],
+        ],
+    },
+    "DeepSeek-V2": {
+        "num_routed_experts": 160,
+        "ffn_shapes": [
+            [5120, 3072],
+            [1536, 5120],
+        ],
+    },
+}
+
+
+# This Triton grouped GEMM kernel is adapted from
+# https://triton-lang.org/main/getting-started/tutorials/08-grouped-gemm.html
+@triton.jit
+def grouped_matmul_kernel(
+    # device tensor of matrix pointers
+    group_a_ptrs,
+    group_b_ptrs,
+    group_c_ptrs,
+    # device tensor of gemm sizes. its shape is [group_size, 3]
+    # dim 0 is group_size, dim 1 is the values of <M, N, K> of each gemm
+    group_gemm_sizes,
+    # device tensor of leading dimension sizes. its shape is [group_size, 3]
+    # dim 0 is group_size, dim 1 is the values of <lda, ldb, ldc> of each gemm
+    g_lds,
+    # Factors for multiplication.
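+    # alphas[g] and betas[g] are per-group scaling factors applied as
+    # C = alpha * (A @ B) + beta * C, matching the standard GEMM convention.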
+ alphas, + betas, + # number of gemms + group_size, + # number of virtual SM + NUM_SM: tl.constexpr, + # tile sizes + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, +): + tile_idx = tl.program_id(0) + last_problem_end = 0 + for g in range(group_size): + # get the gemm size of the current problem + gm = tl.load(group_gemm_sizes + g * 3) + gn = tl.load(group_gemm_sizes + g * 3 + 1) + gk = tl.load(group_gemm_sizes + g * 3 + 2) + num_m_tiles = tl.cdiv(gm, BLOCK_SIZE_M) + num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) + num_tiles = num_m_tiles * num_n_tiles + # load multiplication factors + alpha = tl.load(alphas + g) + beta = tl.load(betas + g) + # iterate through the tiles in the current gemm problem + while tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles: + # pick up a tile from the current gemm problem + k = gk + lda = tl.load(g_lds + g * 3) + ldb = tl.load(g_lds + g * 3 + 1) + ldc = tl.load(g_lds + g * 3 + 2) + a_ptr = tl.load(group_a_ptrs + g).to(tl.pointer_type(tl.float16)) + b_ptr = tl.load(group_b_ptrs + g).to(tl.pointer_type(tl.float16)) + c_ptr = tl.load(group_c_ptrs + g).to(tl.pointer_type(tl.float16)) + # figure out tile coordinates + tile_idx_in_gemm = tile_idx - last_problem_end + tile_m_idx = tile_idx_in_gemm // num_n_tiles + tile_n_idx = tile_idx_in_gemm % num_n_tiles + + # do regular gemm here + offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] + b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): + a = tl.load( + a_ptrs, + mask=(offs_am[:, None] < gm) + and (offs_k[None, :] < gk - kk * BLOCK_SIZE_K), + other=0.0, + ) + b = tl.load( + b_ptrs, + mask=(offs_k[:, None] < gk - kk * BLOCK_SIZE_K) + and (offs_bn[None, :] < gn), + other=0.0, + ) + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K * ldb + accumulator *= alpha + c = accumulator.to(tl.float16) + + offs_cm = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + ldc * offs_cm[:, None] + offs_cn[None, :] + output_mask = (offs_am[:, None] < gm) and (offs_bn[None, :] < gn) + c += beta * tl.load(c_ptrs, mask=output_mask) + tl.store(c_ptrs, c, mask=output_mask) + + # go to the next tile by advancing NUM_SM + tile_idx += NUM_SM + + # get ready to go to the next gemm problem + last_problem_end = last_problem_end + num_tiles + + +def triton_perf_fn(group_A, group_B, group_C, dtype): + # We put the process of matrix lengths and pointers here out of fairness, + # since cublas_grouped_gemm kernel also does these work. 
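+    # The loop below flattens the group into the device-side arrays the Triton kernel
+    # expects: one device pointer per A/B/C matrix, per-group sizes (M, N, K),
+    # per-group leading dimensions (lda, ldb, ldc), and per-group alpha/beta factors.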
+ group_size = len(group_A) + A_addrs = [] + B_addrs = [] + C_addrs = [] + g_sizes = [] + g_lds = [] + alphas = [1.0] * group_size + betas = [0.0] * group_size + for i in range(group_size): + M, N, K = group_A[i].shape[0], group_B[i].shape[1], group_A[i].shape[1] + g_sizes += [M, N, K] + g_lds += [K, N, N] + A_addrs.append(group_A[i].data_ptr()) + B_addrs.append(group_B[i].data_ptr()) + C_addrs.append(group_C[i].data_ptr()) + + d_a_ptrs = torch.tensor(A_addrs, device="cuda") + d_b_ptrs = torch.tensor(B_addrs, device="cuda") + d_c_ptrs = torch.tensor(C_addrs, device="cuda") + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device="cuda") + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device="cuda") + d_alphas = torch.tensor(alphas, dtype=torch.float32, device="cuda") + d_betas = torch.tensor(betas, dtype=torch.float32, device="cuda") + + NUM_SM = 128 + grid = (NUM_SM,) + grouped_matmul_kernel[grid]( + d_a_ptrs, + d_b_ptrs, + d_c_ptrs, + d_g_sizes, + d_g_lds, + d_alphas, + d_betas, + group_size, + NUM_SM=NUM_SM, + BLOCK_SIZE_M=128, + BLOCK_SIZE_N=128, + BLOCK_SIZE_K=32, + ) + + +def cublas_perf_fn(group_A, group_B, group_C, dtype): + cublas_grouped_gemm(group_A, group_B, group_C, dtype) + + +@triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M"], + x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048], + x_log=False, + line_arg="provider", + line_vals=[ + "triton", + "cublas", + ], + line_names=[ + "triton", + "cublas", + ], + styles=[("green", "-"), ("blue", "-")], + ylabel="gbps", + plot_name="grouped gemm", + args={}, + ) +) +def benchmark(M, provider, N, K): + group_size = 20 # Number of used experts per gpu is usually around 20 + group_A = [] + group_B_row_major = [] + group_B_col_major = [] + group_C = [] + dtype = torch.float16 + for i in range(group_size): + A = torch.rand((M, K), device="cuda", dtype=dtype) + B_row_major = torch.rand((K, N), device="cuda", dtype=dtype) + B_col_major = torch.rand((N, K), device="cuda", dtype=dtype) + C = torch.empty((M, N), device="cuda", dtype=dtype) + group_A.append(A) + group_B_row_major.append(B_row_major) + group_B_col_major.append(B_col_major) + group_C.append(C) + + quantiles = [0.5, 0.2, 0.8] + if "triton" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(group_A, group_B_row_major, group_C, dtype), + quantiles=quantiles, + ) + elif "cublas" in provider: + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: cublas_perf_fn(group_A, group_B_col_major, group_C, dtype), + quantiles=quantiles, + ) + + gbps = ( + lambda ms: group_size + * (2 * M * N * K + 2 * M * N) + * group_A[0].element_size() + * 1e-9 + / (ms * 1e-3) + ) + return gbps(ms), gbps(max_ms), gbps(min_ms) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--models", + nargs="+", + type=str, + default=["DeepSeek-V2"], + help="List of models to benchmark", + ) + parser.add_argument( + "--tp-size", + type=int, + default=8, + help="Tensor parallel size", + ) + args = parser.parse_args() + for model in args.models: + assert model in WEIGHT_CONFIGS + num_experts_per_device = ( + WEIGHT_CONFIGS[model]["num_routed_experts"] // args.tp_size + ) + for K, N in WEIGHT_CONFIGS[model]["ffn_shapes"]: + print( + f"{model} N={N} K={K} tp_size={args.tp_size} " + f"group_size=num_experts_per_device={num_experts_per_device}: " + ) + benchmark.run( + print_data=True, + show_plots=True, + save_path="bench_grouped_gemm_res", + N=N, + K=K, + ) + + print("Benchmark finished!") diff --git 
a/sgl-kernel/setup.py b/sgl-kernel/setup.py index b3b67ec51..532f90601 100644 --- a/sgl-kernel/setup.py +++ b/sgl-kernel/setup.py @@ -102,6 +102,7 @@ sources = [ "src/sgl-kernel/csrc/eagle_utils.cu", "src/sgl-kernel/csrc/speculative_sampling.cu", "src/sgl-kernel/csrc/per_token_group_quant_fp8.cu", + "src/sgl-kernel/csrc/cublas_grouped_gemm.cu", "3rdparty/flashinfer/csrc/activation.cu", "3rdparty/flashinfer/csrc/bmm_fp8.cu", "3rdparty/flashinfer/csrc/norm.cu", diff --git a/sgl-kernel/src/sgl-kernel/__init__.py b/sgl-kernel/src/sgl-kernel/__init__.py index f8e532164..25ed6bb74 100644 --- a/sgl-kernel/src/sgl-kernel/__init__.py +++ b/sgl-kernel/src/sgl-kernel/__init__.py @@ -12,6 +12,7 @@ from sgl_kernel.ops import ( bmm_fp8, build_tree_kernel, build_tree_kernel_efficient, + cublas_grouped_gemm, custom_dispose, custom_reduce, fp8_blockwise_scaled_mm, @@ -43,6 +44,7 @@ from .version import __version__ __all__ = [ "apply_rope_with_cos_sin_cache_inplace", "bmm_fp8", + "cublas_grouped_gemm", "custom_dispose", "custom_reduce", "fp8_blockwise_scaled_mm", diff --git a/sgl-kernel/src/sgl-kernel/csrc/cublas_grouped_gemm.cu b/sgl-kernel/src/sgl-kernel/csrc/cublas_grouped_gemm.cu new file mode 100644 index 000000000..ec899d330 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/csrc/cublas_grouped_gemm.cu @@ -0,0 +1,148 @@ +// References: +// https://docs.nvidia.com/cuda/cublas/index.html#cublasgemmgroupedbatchedex +// https://github.com/NVIDIA/CUDALibrarySamples/blob/master/cuBLAS/Extensions/GemmGroupedBatchedEx/cublas_GemmGroupedBatchedEx_example.cu +// https://github.com/zhihu/ZhiLight/blob/main/src/nn/linear/gemm_grouped.cpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "utils.h" + +static void check_group_count(const std::vector& inputs, const std::vector& weights, + const std::vector& outputs) { + TORCH_CHECK(((inputs.size() == weights.size()) && (inputs.size() == outputs.size())), + "The group count of inputs, weights and outputs should be the same."); +} + +static void check_device_dtype(const torch::Dtype& dtype, const std::vector& tensors) { + for (const auto& t : tensors) { + TORCH_CHECK(dtype == t.dtype(), "dtype of all the tensors should be the same"); + TORCH_CHECK(t.is_cuda(), "All tensors should be in Cuda memory"); + } +} + +static std::vector get_dims(const std::vector& tensors, int dim) { + std::vector results; + for (const auto& t : tensors) { + TORCH_CHECK(t.dim() == 2, "Should pass in 2D matrices"); + results.push_back(t.size(dim)); + } + return std::move(results); +} + +static std::vector get_strides(const std::vector& tensors, int dim) { + std::vector results; + for (const auto& t : tensors) { + results.push_back(t.stride(dim)); + } + return std::move(results); +} + +static void check_equal(const std::vector& a, const std::vector& b, const std::string& err_msg) { + for (int i = 0; i < a.size(); ++i) { + TORCH_CHECK(a[i] == b[i], err_msg); + } +} + +static std::vector get_tensor_ptrs(const std::vector& tensors) { + std::vector ptrs; + for (auto& t : tensors) { + ptrs.push_back(t.data_ptr()); + } + return std::move(ptrs); +} + +static torch::Tensor create_ptr_pointer(const std::vector& ptrs, cudaStream_t stream) { + auto options = torch::TensorOptions().dtype(torch::kDouble).device(torch::kCUDA); + torch::Tensor gpu_ptrs = torch::empty({static_cast(ptrs.size())}, options); + TORCH_CHECK(cudaMemcpyAsync(gpu_ptrs.data_ptr(), ptrs.data(), sizeof(void*) * ptrs.size(), 
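+  // cublasGemmGroupedBatchedEx requires the per-group A/B/C pointer arrays to reside
+  // in device memory, so the host-side pointer list is staged into a temporary CUDA
+  // tensor (kDouble is used only because its 8-byte elements match sizeof(void*)).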
cudaMemcpyHostToDevice, + stream) == CUBLAS_STATUS_SUCCESS); + return gpu_ptrs; +} + +// We want compute input @ weight^T in row major +// This is equivalent to computing weight @ input^T in col major +// Cublas only accepts matrix in column major, so this arrangement is needed +void cublas_grouped_gemm(const std::vector& inputs, // b: (m, k) row major = (k, m) col major + const std::vector& weights, // a: (n, k) row major = (n, k)^T col major + const std::vector& outputs, // c: (m, n) row major = (n, m) col major + const torch::Dtype& out_dtype, int64_t cublas_handle, int64_t cuda_stream) { + TORCH_CHECK(out_dtype == torch::kHalf || out_dtype == torch::kBFloat16, + "cublas grouped_gemm can" + "only be applied to float16 and bfloat16 dtype"); + + int group_count = inputs.size(); + check_group_count(inputs, weights, outputs); + std::vector group_size(group_count, 1); + + // Make sure all tensors are on cuda and use the same dtype + check_device_dtype(out_dtype, inputs); + check_device_dtype(out_dtype, weights); + check_device_dtype(out_dtype, outputs); + cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF); + + // Weights should be transposed to (n, k) of column major + std::vector transa_array(group_count, CUBLAS_OP_T); + std::vector transb_array(group_count, CUBLAS_OP_N); + + // Get dim arrays + std::vector m_array = get_dims(weights, 0); + std::vector n_array = get_dims(inputs, 0); + std::vector k_array = get_dims(inputs, 1); + + // Make sure the dimensions in each group match + std::vector m_array1 = get_dims(outputs, 1); + std::vector n_array1 = get_dims(outputs, 0); + std::vector k_array1 = get_dims(weights, 1); + check_equal(m_array, m_array1, "sizes don't match on m dimension"); + check_equal(n_array, n_array1, "sizes don't match on n dimension"); + check_equal(k_array, k_array1, "sizes don't match on k dimension"); + + // Get leading dimensions + std::vector lda_array = get_strides(weights, 0); + std::vector ldb_array = get_strides(inputs, 0); + std::vector ldc_array = get_strides(outputs, 0); + + // Use default scaling factors + std::vector alpha_array(group_count, 1); + std::vector beta_array(group_count, 0); + + std::vector a_array = get_tensor_ptrs(weights); + std::vector b_array = get_tensor_ptrs(inputs); + std::vector c_array = get_tensor_ptrs(outputs); + + auto handle = reinterpret_cast(cublas_handle); + auto stream = reinterpret_cast(cuda_stream); + + // Should allocate tensors for storage of pointers + torch::Tensor d_a = create_ptr_pointer(a_array, stream); + torch::Tensor d_b = create_ptr_pointer(b_array, stream); + torch::Tensor d_c = create_ptr_pointer(c_array, stream); + +#if defined CUDA_VERSION && CUDA_VERSION >= 12050 + auto status = cublasGemmGroupedBatchedEx(handle, transa_array.data(), transb_array.data(), m_array.data(), + n_array.data(), k_array.data(), alpha_array.data(), (void**)d_a.data_ptr(), + cuda_data_type, lda_array.data(), (void**)d_b.data_ptr(), cuda_data_type, + ldb_array.data(), beta_array.data(), (void**)d_c.data_ptr(), cuda_data_type, + ldc_array.data(), group_count, group_size.data(), CUBLAS_COMPUTE_32F); + TORCH_CHECK(status == CUBLAS_STATUS_SUCCESS, "cublas grouped gemm failed: ", cublasGetStatusString(status)); + TORCH_CHECK(cudaStreamSynchronize(stream) == cudaSuccess, "Failed when stream synchronization"); + return; +#endif + + TORCH_CHECK_NOT_IMPLEMENTED(false, + "Cublas GroupGemm is not implemented with current compute capability: ", getSMVersion()); +} diff --git 
a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h index 97a4ba593..92a68d622 100644 --- a/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h +++ b/sgl-kernel/src/sgl-kernel/include/sgl_kernels_ops.h @@ -152,3 +152,8 @@ void build_tree_kernel(at::Tensor parent_list, at::Tensor selected_index, at::Te // sgl_per_token_group_quant_fp8 void sgl_per_token_group_quant_fp8(at::Tensor input, at::Tensor output_q, at::Tensor output_s, int64_t group_size, double eps, double fp8_min, double fp8_max); + +// cublas grouped gemm +void cublas_grouped_gemm(const std::vector& inputs, const std::vector& weights, + const std::vector& outputs, const torch::Dtype& out_dtype, + int64_t cublas_handle, int64_t cuda_stream); diff --git a/sgl-kernel/src/sgl-kernel/ops/__init__.py b/sgl-kernel/src/sgl-kernel/ops/__init__.py index 0a3ea9aca..9848323ad 100644 --- a/sgl-kernel/src/sgl-kernel/ops/__init__.py +++ b/sgl-kernel/src/sgl-kernel/ops/__init__.py @@ -1,5 +1,5 @@ import os -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import sgl_kernel.ops._kernels import torch @@ -603,3 +603,24 @@ def sgl_per_token_group_quant_fp8( torch.ops.sgl_kernels.sgl_per_token_group_quant_fp8( input, output_q, output_s, group_size, eps, fp8_min, fp8_max ) + + +def cublas_grouped_gemm( + inputs: List[torch.Tensor], + weights: List[torch.Tensor], + outputs: List[torch.Tensor], + out_dtype: torch.dtype, +) -> None: + with inputs[0].device as device: + assert ( + len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0 + ), "Inputs/weights/outputs should not be empty!" + cublas_handle = torch.cuda.current_blas_handle() + torch.ops.sgl_kernels.cublas_grouped_gemm( + inputs, + weights, + outputs, + out_dtype, + cublas_handle, + _get_cuda_stream(device), + ) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension.cc b/sgl-kernel/src/sgl-kernel/torch_extension.cc index 861c98e29..0585911c4 100644 --- a/sgl-kernel/src/sgl-kernel/torch_extension.cc +++ b/sgl-kernel/src/sgl-kernel/torch_extension.cc @@ -165,6 +165,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernels, m) { "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size," " float eps, float fp8_min, float fp8_max) -> ()"); m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8); + + // cublas grouped gemm + m.def( + "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs," + " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()"); + m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm); } REGISTER_EXTENSION(_kernels) diff --git a/sgl-kernel/tests/test_cublas_grouped_gemm.py b/sgl-kernel/tests/test_cublas_grouped_gemm.py new file mode 100644 index 000000000..9aac569f2 --- /dev/null +++ b/sgl-kernel/tests/test_cublas_grouped_gemm.py @@ -0,0 +1,49 @@ +import unittest + +import torch +from sgl_kernel import cublas_grouped_gemm + + +def torch_grouped_gemm(a_array, b_array, out_dtype): + c_array = [] + for a, b in zip(a_array, b_array): + c_array.append(torch.matmul(a, b.t()).to(out_dtype)) + return c_array + + +class TestGroupedGemm(unittest.TestCase): + def _test_accuracy(self, Ms, Ns, Ks, out_dtype): + group_count = len(Ms) + a_array = [] + b_array = [] + c_array_cublas = [] + for i in range(group_count): + M, N, K = Ms[i], Ns[i], Ks[i] + a_array.append(torch.randn((M, K), device="cuda", dtype=out_dtype) * 5) + b_array.append(torch.randn((N, K), device="cuda", dtype=out_dtype) * 5) + 
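+            # b is created in (N, K) layout because cublas_grouped_gemm computes
+            # outputs[i] = inputs[i] @ weights[i].T, mirroring torch_grouped_gemm above.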
+            c_array_cublas.append(torch.empty((M, N), device="cuda", dtype=out_dtype))
+
+        c_array_torch = torch_grouped_gemm(a_array, b_array, out_dtype)
+        cublas_grouped_gemm(a_array, b_array, c_array_cublas, out_dtype)
+
+        for i in range(group_count):
+            M, N, K = Ms[i], Ns[i], Ks[i]
+            torch.testing.assert_close(c_array_torch[i], c_array_cublas[i])
+            print(f"M={M}, N={N}, K={K}, out_dtype={out_dtype}: OK")
+
+    def test_accuracy(self):
+        Ms = [1, 16, 32, 256, 1024]
+        Ns = [2, 16, 128, 256, 4096]
+        Ks = [3, 16, 32, 512, 8192]
+        out_dtypes = [torch.float16, torch.bfloat16]
+        for out_dtype in out_dtypes:
+            self._test_accuracy(Ms, Ns, Ks, out_dtype)
+
+
+if __name__ == "__main__":
+    if torch.cuda.is_available():
+        cuda_version = tuple(map(int, torch.version.cuda.split(".")))
+        if cuda_version >= (12, 5):
+            unittest.main()
+        else:
+            print(f"CUDA version {cuda_version} is lower than 12.5; not executing tests.")
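For reference, a minimal end-to-end usage sketch of the new op, mirroring the accuracy test above (shapes are illustrative; the CUDA >= 12.5 setup and rebuilt sgl-kernel described in `docs/references/deepseek.md` are assumed):

```python
import torch
from sgl_kernel import cublas_grouped_gemm

dtype = torch.float16
# One (M, N, K) problem per group member; shapes may differ across the group.
problem_shapes = [(16, 2048, 2816), (32, 1408, 2048)]

inputs = [torch.randn(M, K, device="cuda", dtype=dtype) for M, N, K in problem_shapes]
weights = [torch.randn(N, K, device="cuda", dtype=dtype) for M, N, K in problem_shapes]
outputs = [torch.empty(M, N, device="cuda", dtype=dtype) for M, N, K in problem_shapes]

# Computes outputs[i] = inputs[i] @ weights[i].T for all group members in a single call.
cublas_grouped_gemm(inputs, weights, outputs, dtype)

for x, w, out in zip(inputs, weights, outputs):
    torch.testing.assert_close(out, torch.matmul(x, w.t()))
```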