Simplify tests & Fix trtllm custom allreduce registration (#4252)
This commit is contained in:
@@ -100,7 +100,6 @@ void cublas_grouped_gemm(
|
||||
check_device_dtype(out_dtype, inputs);
|
||||
check_device_dtype(out_dtype, weights);
|
||||
check_device_dtype(out_dtype, outputs);
|
||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||
|
||||
// Weights should be transposed to (n, k) of column major
|
||||
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
||||
@@ -132,7 +131,6 @@ void cublas_grouped_gemm(
|
||||
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
||||
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
||||
|
||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
|
||||
// Should allocate tensors for storage of pointers
|
||||
@@ -141,6 +139,9 @@ void cublas_grouped_gemm(
|
||||
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||
|
||||
auto status = cublasGemmGroupedBatchedEx(
|
||||
handle,
|
||||
transa_array.data(),
|
||||
|
||||
@@ -32,11 +32,8 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
|
||||
m.impl("all_reduce", torch::kCUDA, &all_reduce);
|
||||
|
||||
m.def("get_graph_buffer_ipc_meta(int fa) -> (int[], int[])");
|
||||
m.impl("get_graph_buffer_ipc_meta", torch::kCUDA, &get_graph_buffer_ipc_meta);
|
||||
|
||||
m.def("register_graph_buffers(int fa, int[][] handles, int[][] offsets) -> ()");
|
||||
m.impl("register_graph_buffers", torch::kCUDA, &register_graph_buffers);
|
||||
m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
|
||||
m.def("register_graph_buffers", &register_graph_buffers);
|
||||
|
||||
/*
|
||||
* From csrc/attention
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user