diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt
index b74fefb77..095ad47f7 100644
--- a/sgl-kernel/CMakeLists.txt
+++ b/sgl-kernel/CMakeLists.txt
@@ -247,6 +247,7 @@ set(SOURCES
     "csrc/attention/vertical_slash_index.cu"
     "csrc/elementwise/activation.cu"
     "csrc/elementwise/cast.cu"
+    "csrc/elementwise/copy.cu"
     "csrc/elementwise/fused_add_rms_norm_kernel.cu"
     "csrc/elementwise/rope.cu"
     "csrc/common_extension.cc"
diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc
index 54587b1be..18a141af1 100644
--- a/sgl-kernel/csrc/common_extension.cc
+++ b/sgl-kernel/csrc/common_extension.cc
@@ -445,6 +445,9 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "qserve_w4a8_per_group_gemm(Tensor _in_feats, Tensor _kernel, Tensor _zeros, Tensor _scales_i8, Tensor _wscales, "
       "Tensor _ascales, Tensor! _out_feats) -> ()");
   m.impl("qserve_w4a8_per_group_gemm", torch::kCUDA, &qserve_w4a8_per_group_gemm);
+
+  m.def("copy_to_gpu_no_ce(Tensor input, Tensor! output) -> ()");
+  m.impl("copy_to_gpu_no_ce", torch::kCUDA, &copy_to_gpu_no_ce);
 }
 
 REGISTER_EXTENSION(common_ops)
diff --git a/sgl-kernel/csrc/elementwise/copy.cu b/sgl-kernel/csrc/elementwise/copy.cu
new file mode 100644
index 000000000..09719f510
--- /dev/null
+++ b/sgl-kernel/csrc/elementwise/copy.cu
@@ -0,0 +1,58 @@
+// NOTE(review): the original #include targets were destroyed by markup
+// stripping (the same stripping that ate every "<...>" in this patch).
+// These are the headers the code below actually needs (TORCH_CHECK /
+// torch::kInt32, getCurrentCUDAStream, C10_CUDA_KERNEL_LAUNCH_CHECK,
+// dim3/cudaStream_t) — confirm against the upstream commit.
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAException.h>
+#include <torch/all.h>
+
+#include <cuda_runtime.h>
+
+// Fixed-size POD carrier: the host values ride to the device inside the
+// kernel-argument buffer itself, so no separate H2D transfer (and thus no
+// copy-engine usage) is required.
+template <int N>
+struct InputArray {
+  int values[N];
+};
+
+// One thread per element writes the kernel-argument payload into `output`.
+template <int N>
+__global__ void copy_to_gpu_no_ce_kernel(const InputArray<N> input_array, int* output) {
+  int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < N) {
+    output[idx] = input_array.values[idx];
+  }
+}
+
+// Validates the tensors and launches the kernel for a compile-time size N.
+template <int N>
+void copy_to_gpu_no_ce_impl(const at::Tensor& input, at::Tensor& output) {
+  TORCH_CHECK(input.dim() == 1, "input must be 1-D");
+  TORCH_CHECK(static_cast<int>(input.numel()) == N, "input numel must equal template N");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
+  TORCH_CHECK(input.dtype() == torch::kInt32, "input dtype must be int32");
+
+  TORCH_CHECK(output.dim() == 1, "output dim");
+  TORCH_CHECK(static_cast<int>(output.numel()) == N, "output size");
+  TORCH_CHECK(output.is_contiguous(), "output contiguous");
+  TORCH_CHECK(output.dtype() == torch::kInt32, "output dtype");
+
+  TORCH_CHECK(input.device().is_cpu(), "input must be a CPU tensor");
+  TORCH_CHECK(output.device().is_cuda(), "output must be a CUDA tensor");
+
+  // Stage the host values into the kernel-argument struct; the launch below
+  // carries them to the device without a cudaMemcpy.
+  InputArray<N> input_array;
+  const int* input_ptr = input.data_ptr<int>();
+  for (int i = 0; i < N; ++i)
+    input_array.values[i] = input_ptr[i];
+
+  // may use multi thread blocks if performance bottleneck
+  dim3 grid(1);
+  dim3 block(static_cast<unsigned int>(input.numel()));
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+  copy_to_gpu_no_ce_kernel<N><<<grid, block, 0, stream>>>(input_array, output.data_ptr<int>());
+  C10_CUDA_KERNEL_LAUNCH_CHECK();
+}
+
+// Runtime dispatcher: maps the tensor length onto the compiled template
+// instantiations (only the sizes listed here are supported).
+void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output) {
+  int N = static_cast<int>(input.numel());
+  // Can use macro if there are more N needed
+  if (N == 72) {
+    copy_to_gpu_no_ce_impl<72>(input, output);
+  } else if (N == 64) {
+    copy_to_gpu_no_ce_impl<64>(input, output);
+  } else {
+    TORCH_CHECK(false, "unexpected N");
+  }
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index b6c40c801..0b4b979ab 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -750,3 +750,5 @@ std::vector<int64_t> create_greenctx_stream_by_value(int64_t smA, int64_t smB, i
  * From csrc/memory
  */
 void store_kv_cache(at::Tensor k_cache, at::Tensor v_cache, at::Tensor out_loc, at::Tensor k, at::Tensor v);
+
+void copy_to_gpu_no_ce(const at::Tensor& input, at::Tensor& output);
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index cf771d553..0476ad696 100755
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -23,6 +23,7 @@ from sgl_kernel.cutlass_moe import cutlass_w4a8_moe_mm, get_cutlass_w4a8_moe_mm_
 from sgl_kernel.elementwise import (
import ( FusedSetKVBufferArg, apply_rope_with_cos_sin_cache_inplace, + copy_to_gpu_no_ce, downcast_fp8, fused_add_rmsnorm, gelu_and_mul, diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 9abfe8384..863b4d97e 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import torch from sgl_kernel.utils import get_cuda_stream, is_arch_support_pdl @@ -367,3 +367,7 @@ def downcast_fp8( torch.ops.sgl_kernel.downcast_fp8( k, v, k_out, v_out, k_scale, v_scale, loc, mult, offset, get_cuda_stream() ) + + +def copy_to_gpu_no_ce(input: List[int], output: torch.Tensor): + torch.ops.sgl_kernel.copy_to_gpu_no_ce(input, output)