Simplify tests & Fix trtllm custom allreduce registration (#4252)

Author: Lianmin Zheng
Date: 2025-03-10 01:24:22 -07:00
Committed by: GitHub
Parent: 007f8b3dc2
Commit: aa957102a9
13 changed files with 30 additions and 211 deletions


@@ -100,7 +100,6 @@ void cublas_grouped_gemm(
   check_device_dtype(out_dtype, inputs);
   check_device_dtype(out_dtype, weights);
   check_device_dtype(out_dtype, outputs);
-  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
   // Weights should be transposed to (n, k) of column major
   std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
@@ -132,7 +131,6 @@ void cublas_grouped_gemm(
   std::vector<void*> b_array = get_tensor_ptrs(inputs);
   std::vector<void*> c_array = get_tensor_ptrs(outputs);
-  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
   auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
   // Should allocate tensors for storage of pointers
@@ -141,6 +139,9 @@ void cublas_grouped_gemm(
   torch::Tensor d_c = create_ptr_pointer(c_array, stream);
 #if defined CUDA_VERSION && CUDA_VERSION >= 12050
+  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
+  cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
   auto status = cublasGemmGroupedBatchedEx(
       handle,
       transa_array.data(),
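
cublasGemmGroupedBatchedEx only exists from CUDA 12.5 onward, so the guarded branch is compiled out on older toolkits; moving the handle and cuda_data_type declarations inside the guard keeps them from becoming unused variables on those builds. A minimal sketch of the pattern, assuming a toy function name and error fallback (neither is the file's exact code):

// Sketch only: the declarations live inside the version guard so that
// pre-12.5 builds never see them. grouped_gemm_guard_sketch is hypothetical.
#include <cublas_v2.h>
#include <cuda.h>
#include <torch/all.h>

void grouped_gemm_guard_sketch(int64_t cublas_handle, at::ScalarType out_dtype) {
#if defined(CUDA_VERSION) && CUDA_VERSION >= 12050
  auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
  cudaDataType_t cuda_data_type =
      (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
  // ... cublasGemmGroupedBatchedEx(handle, ...) would go here ...
  (void)handle;
  (void)cuda_data_type;
#else
  TORCH_CHECK(false, "cublas_grouped_gemm requires CUDA >= 12.5");
#endif
}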


@@ -32,11 +32,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
   m.impl("all_reduce", torch::kCUDA, &all_reduce);
-  m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
-  m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
-  m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);
+  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", &register_graph_buffers);
   /*
    * From csrc/attention
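
This registration change is the allreduce fix named in the title: get_graph_buffer_ipc_meta and register_graph_buffers take no Tensor arguments, so the dispatcher has no device key to route on and a torch::kCUDA impl can never be selected. Registering them through the schema-less m.def(name, &fn) overload installs a catch-all kernel instead, while tensor-taking ops such as all_reduce keep the schema + kCUDA form. A hedged sketch of the two patterns (the sketch_ops library and the toy op are illustrative, not sgl_kernel's real code):

#include <torch/library.h>

#include <tuple>
#include <vector>

// Toy tensor-free op standing in for get_graph_buffer_ipc_meta.
std::tuple<std::vector<int64_t>, std::vector<int64_t>> ipc_meta_sketch(int64_t fa) {
  return std::make_tuple(std::vector<int64_t>{fa}, std::vector<int64_t>{0});
}

TORCH_LIBRARY_FRAGMENT(sketch_ops, m) {
  // Dispatched form: valid only when some argument (here a Tensor)
  // carries a device key for the dispatcher to route on:
  //   m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
  //   m.impl("all_reduce", torch::kCUDA, &all_reduce);

  // Catch-all form for tensor-free ops: the schema is inferred from
  // the C++ signature and the kernel is reachable without a device key.
  m.def("ipc_meta_sketch", &ipc_meta_sketch);
}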


@@ -1,9 +1,7 @@
-import math
-from typing import Any, Dict, List, Optional, Tuple, Union
+import pytest
 import torch
-import torch.nn as nn
 from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
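
After the cleanup the test file needs only pytest, torch, and the kernel under test. A sketch of the simplified pytest style these imports suggest (the parametrization, shapes, and check are illustrative, not the file's actual cases):

import pytest
import torch


@pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_style_sketch(batch_size, dtype):
    # Plain parametrized functions keep the file short: each
    # (batch_size, dtype) pair becomes its own test case.
    x = torch.randn(batch_size, 16, dtype=dtype)
    torch.testing.assert_close(x, x.clone())


if __name__ == "__main__":
    pytest.main([__file__])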