Simplify tests & Fix trtllm custom allreduce registration (#4252)
This commit is contained in:
@@ -100,7 +100,6 @@ void cublas_grouped_gemm(
|
||||
check_device_dtype(out_dtype, inputs);
|
||||
check_device_dtype(out_dtype, weights);
|
||||
check_device_dtype(out_dtype, outputs);
|
||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||
|
||||
// Weights should be transposed to (n, k) of column major
|
||||
std::vector<cublasOperation_t> transa_array(group_count, CUBLAS_OP_T);
|
||||
@@ -132,7 +131,6 @@ void cublas_grouped_gemm(
|
||||
std::vector<void*> b_array = get_tensor_ptrs(inputs);
|
||||
std::vector<void*> c_array = get_tensor_ptrs(outputs);
|
||||
|
||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||
auto stream = reinterpret_cast<cudaStream_t>(cuda_stream);
|
||||
|
||||
// Should allocate tensors for storage of pointers
|
||||
@@ -141,6 +139,9 @@ void cublas_grouped_gemm(
|
||||
torch::Tensor d_c = create_ptr_pointer(c_array, stream);
|
||||
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12050
|
||||
auto handle = reinterpret_cast<cublasHandle_t>(cublas_handle);
|
||||
cudaDataType_t cuda_data_type = (out_dtype == torch::kHalf ? CUDA_R_16F : CUDA_R_16BF);
|
||||
|
||||
auto status = cublasGemmGroupedBatchedEx(
|
||||
handle,
|
||||
transa_array.data(),
|
||||
|
||||
Reference in New Issue
Block a user