Support copying tensor from cpu to gpu without using copy engines (#10007)
This commit is contained in:
@@ -247,6 +247,7 @@ set(SOURCES
|
||||
"csrc/attention/vertical_slash_index.cu"
|
||||
"csrc/elementwise/activation.cu"
|
||||
"csrc/elementwise/cast.cu"
|
||||
"csrc/elementwise/copy.cu"
|
||||
"csrc/elementwise/fused_add_rms_norm_kernel.cu"
|
||||
"csrc/elementwise/rope.cu"
|
||||
"csrc/common_extension.cc"
|
||||
|
||||
@@ -445,6 +445,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
|
||||
"Tensor _ascales, Tensor! _out_feats) -> ()");
|
||||
m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
|
||||
|
||||
m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
|
||||
m.impl("copy_to_gpu_no_ce", torch::kCUDA, ©_to_gpu_no_ce);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
58
sgl-kernel/csrc/elementwise/copy.cu
Normal file
58
sgl-kernel/csrc/elementwise/copy.cu
Normal file
@@ -0,0 +1,58 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
// Fixed-size POD payload passed *by value* as a kernel argument: the values
// are embedded in the kernel's parameter buffer, so launching the kernel
// transfers them to the device without engaging a copy engine.
// NOTE(review): kernel parameter space is limited (traditionally 4 KB), so
// N must stay small — the sizes dispatched below (64/72 ints) are well within it.
template <int N>
struct InputArray {
  int values[N];
};
|
||||
|
||||
// Unpacks the by-value kernel argument `input_array` into device memory.
// Launch with at least N threads in total; surplus threads exit via the
// bounds guard below.
template <int N>
__global__ void copy_to_gpu_no_ce_kernel(const InputArray<N> input_array, int* output) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= N) {
    return;
  }
  output[tid] = input_array.values[tid];
}
|
||||
|
||||
// Copies a 1-D int32 CPU tensor of exactly N elements into a pre-allocated
// 1-D int32 CUDA tensor, bypassing the copy engines: the data is staged into
// an InputArray<N> and shipped as a kernel argument.
//
// Preconditions (checked): both tensors are 1-D, contiguous, int32, with
// numel == N; `input` is on CPU, `output` is on CUDA.
// Launches asynchronously on the current stream of `output`'s device.
template <int N>
void copy_to_gpu_no_ce_impl(const at::Tensor& input, at::Tensor& output) {
  // Grid is fixed at one block below, so N must fit in a single thread block.
  static_assert(N >= 1 && N <= 1024, "N must fit in one thread block");

  TORCH_CHECK(input.dim() == 1, "input must be 1-D");
  TORCH_CHECK(static_cast<int>(input.numel()) == N, "input numel must equal template N");
  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
  TORCH_CHECK(input.dtype() == torch::kInt32, "input dtype must be int32");

  TORCH_CHECK(output.dim() == 1, "output must be 1-D");
  TORCH_CHECK(static_cast<int>(output.numel()) == N, "output numel must equal template N");
  TORCH_CHECK(output.is_contiguous(), "output must be contiguous");
  TORCH_CHECK(output.dtype() == torch::kInt32, "output dtype must be int32");

  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
  TORCH_CHECK(output.device().is_cuda(), "output must be a CUDA tensor");

  // Fix: select the device that owns `output` before querying the stream and
  // launching; otherwise, in a multi-GPU process, the launch could target the
  // wrong device. (CUDAGuard.h was already included but never used.)
  const c10::cuda::CUDAGuard device_guard(output.device());

  // Stage the host data into the by-value kernel argument.
  InputArray<N> input_array;
  const int* input_ptr = input.data_ptr<int>();
  for (int i = 0; i < N; ++i) {
    input_array.values[i] = input_ptr[i];
  }

  // One block of N threads suffices for the tiny sizes dispatched here;
  // may use multiple thread blocks if this becomes a performance bottleneck.
  dim3 grid(1);
  dim3 block(N);
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  copy_to_gpu_no_ce_kernel<N><<<grid, block, 0, stream>>>(input_array, output.data_ptr<int>());
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
|
||||
|
||||
// Public entry point: copies a small 1-D int32 CPU tensor to a pre-allocated
// CUDA tensor without using the copy engines (payload travels as a kernel
// argument). Only the sizes currently needed are instantiated; add cases
// below (or introduce a dispatch macro) if more sizes are required.
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output) {
  const int N = static_cast<int>(input.numel());
  switch (N) {
    case 72:
      copy_to_gpu_no_ce_impl<72>(input, output);
      break;
    case 64:
      copy_to_gpu_no_ce_impl<64>(input, output);
      break;
    default:
      // Include the offending size so the failure is actionable.
      TORCH_CHECK(false, "copy_to_gpu_no_ce: unexpected input numel ", N, "; supported sizes are 64 and 72");
  }
}
|
||||
@@ -750,3 +750,5 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
|
||||
* From csrc/memory
|
||||
*/
|
||||
void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
|
||||
|
||||
void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
|
||||
|
||||
@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
|
||||
from sgl_kernel.elementwise import (
|
||||
FusedSetKVBufferArg,
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
copy_to_gpu_no_ce,
|
||||
downcast_fp8,
|
||||
fused_add_rmsnorm,
|
||||
gelu_and_mul,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl
|
||||
@@ -367,3 +367,7 @@ def downcast_fp8(
|
||||
torch.ops.sgl_kernel.downcast_fp8(
|
||||
k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream()
|
||||
)
|
||||
|
||||
|
||||
def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor) -> None:
    """Copy small int32 values to a CUDA tensor without using the copy engines.

    The registered op schema is ``copy_to_gpu_no_ce(Tensor input, Tensor! output)``
    and the kernel requires a contiguous 1-D int32 CPU tensor, so a plain
    Python list is converted to such a tensor here before dispatch.

    Args:
        input: Values to copy — a ``List[int]`` or a contiguous 1-D int32
            CPU tensor whose length matches ``output`` (only lengths 64 and
            72 are supported by the kernel).
        output: Pre-allocated 1-D int32 CUDA tensor, written in place.
    """
    if not isinstance(input, torch.Tensor):
        # The C++ op takes a Tensor, not a list; build the expected CPU int32 tensor.
        input = torch.tensor(input, dtype=torch.int32, device="cpu")
    torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)
|
||||
|
||||
Reference in New Issue
Block a user