[AMD] Support Hierarchical Caching on AMD GPUs (#8236)
@@ -121,6 +121,48 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
    */
   m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()");
   m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace);
+
+  /*
+   * From csrc/kvcacheio
+   */
+  m.def(
+      "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
+  m.def(
+      "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
+      "dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
+  m.def(
+      "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer);
+  m.def(
+      "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, "
+      "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int "
+      "num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf);
+  m.def(
+      "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int "
+      "block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
+  m.def(
+      "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
+      "int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
+  m.def(
+      "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla);
+  m.def(
+      "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, "
+      "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()");
+  m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf);
+  m.def(
+      "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int "
+      "page_size) -> ()");
+  m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct);
 }
 
 REGISTER_EXTENSION(common_ops)
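The transfer ops above are defined under the library name given to TORCH_LIBRARY_EXPAND(sgl_kernel, m), and the torch::kCUDA impls also serve AMD GPUs because PyTorch's ROCm build exposes HIP devices under the CUDA device type and dispatch key. A minimal sketch (not part of this commit) of resolving and calling one of these ops from C++ through the dispatcher; the wrapper name and placeholder arguments are illustrative, and the "sgl_kernel::" prefix assumes the macro registers into that namespace:

#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

// Hypothetical helper: look up the schema defined above and call it on GPU tensors.
// Tensors must already live on the device so the torch::kCUDA (CUDA-or-HIP) kernel is selected.
void call_transfer_kv_per_layer(
    const at::Tensor& src_k, const at::Tensor& dst_k, const at::Tensor& src_v, const at::Tensor& dst_v,
    const at::Tensor& src_indices, const at::Tensor& dst_indices,
    int64_t item_size, int64_t block_quota, int64_t num_warps_per_block) {
  static auto op = c10::Dispatcher::singleton()
                       .findSchemaOrThrow("sgl_kernel::transfer_kv_per_layer", "")
                       .typed<void(
                           const at::Tensor&, const at::Tensor&, const at::Tensor&, const at::Tensor&,
                           const at::Tensor&, const at::Tensor&, int64_t, int64_t, int64_t)>();
  op.call(src_k, dst_k, src_v, dst_v, src_indices, dst_indices,
          item_size, block_quota, num_warps_per_block);
}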
@@ -4,21 +4,31 @@
 
 #include <cstdint>
 
+#ifndef USE_ROCM
+#define WARP_SIZE 32
 #include "pytorch_extension_utils.h"
+#else
+#include "pytorch_extension_utils_rocm.h"
+#include "utils.h"  // WARP_SIZE
+#endif
 
 __device__ __forceinline__ void
 transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
   // todo, different chunk size
-  int total_chunks = item_size_bytes / 8;
-  const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr);
-  int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr);
+  const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr);
+  uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr);
+  const int total_chunks = item_size_bytes / sizeof(uint64_t);
 
 #pragma unroll
-  for (int j = lane_id; j < total_chunks; j += 32) {
-    const int64_t* src_addr_lane = &src_8[j];
-    int64_t* dst_addr_lane = &dst_8[j];
-    int64_t temp_val;
-    asm volatile("ld.global.nc.b64 %0, [%1];" : "=l"(temp_val) : "l"(src_addr_lane) : "memory");
-    asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory");
+  for (int j = lane_id; j < total_chunks; j += WARP_SIZE) {
+#ifndef USE_ROCM
+    uint64_t tmp;
+    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory");
+    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory");
+
+#else
+    uint64_t tmp = __builtin_nontemporal_load(src + j);
+    __builtin_nontemporal_store(tmp, dst + j);
+#endif
   }
 }
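transfer_item_warp copies one item of item_size_bytes (a multiple of 8) with a whole warp: on NVIDIA it keeps the PTX path (a non-coherent global load plus a cache-at-global-level store), while on ROCm it uses Clang's __builtin_nontemporal_load/__builtin_nontemporal_store, and the loop stride follows WARP_SIZE, which is 32 on NVIDIA and 64 on AMD wavefronts. A minimal sketch (not part of this commit) of a toy kernel that drives the helper, one item per warp:

// Hypothetical driver kernel, assuming transfer_item_warp is visible in this translation unit.
// Warp w copies item w; item_size_bytes must be a multiple of 8 since the helper copies 8-byte chunks.
__global__ void copy_items_kernel(
    const void* __restrict__ src, void* __restrict__ dst, int64_t item_size_bytes, int64_t num_items) {
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t lane_id = tid % WARP_SIZE;
  const int64_t warp_id = tid / WARP_SIZE;
  if (warp_id >= num_items) return;  // one warp per item; extra warps exit
  const char* src_item = static_cast<const char*>(src) + warp_id * item_size_bytes;
  char* dst_item = static_cast<char*>(dst) + warp_id * item_size_bytes;
  transfer_item_warp(lane_id, src_item, dst_item, item_size_bytes);
}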
@@ -78,8 +88,8 @@ __global__ void transfer_kernel_impl(
     const uintptr_t* __restrict__ src_v_layer_tbl,
     const uintptr_t* __restrict__ dst_v_layer_tbl) {
   int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
-  int32_t lane_id = tid % 32;
-  int32_t warp_id = tid / 32;
+  int32_t lane_id = tid % WARP_SIZE;
+  int32_t warp_id = tid / WARP_SIZE;
 
   for (int i = 0; i < items_per_warp; ++i) {
     int64_t item_id = warp_id * items_per_warp + i;
@@ -139,7 +149,7 @@ void transfer_kv_launcher(
   const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
   const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
   dim3 grid_dim(num_blocks, 1, 1);
-  const int32_t threads_per_block = num_warps_per_block * 32;
+  const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;
 
   const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
   void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;
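The launch geometry above caps the grid at block_quota blocks and gives each warp up to items_per_warp items; only the warp width changes between vendors. A worked example with made-up numbers (not from this commit), assuming div_up(a, b) == (a + b - 1) / b:

// Illustrative launch-geometry check; names and values are hypothetical.
constexpr int64_t div_up_ex(int64_t a, int64_t b) { return (a + b - 1) / b; }

constexpr int64_t kNumItems = 1000, kBlockQuota = 16, kWarpsPerBlock = 4;
constexpr int64_t kItemsPerWarp = div_up_ex(kNumItems, kBlockQuota * kWarpsPerBlock);  // div_up(1000, 64) = 16
constexpr int64_t kNumBlocks = div_up_ex(kNumItems, kItemsPerWarp * kWarpsPerBlock);   // div_up(1000, 64) = 16 (<= kBlockQuota)
// threads_per_block = kWarpsPerBlock * WARP_SIZE: 128 on NVIDIA (32-wide warps), 256 on AMD CDNA (64-wide wavefronts).
static_assert(kNumBlocks * kWarpsPerBlock * kItemsPerWarp >= kNumItems, "every item is assigned to a warp");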
@@ -1,20 +0,0 @@
-#include <torch/library.h>
-
-#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
-
-#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_LAST_DIM_CONTIGUOUS(x) \
-  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
-
-#define CHECK_INPUT(x) \
-  CHECK_CUDA(x);       \
-  CHECK_CONTIGUOUS(x)
-#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
-  CHECK_CUDA(x);                           \
-  CHECK_LAST_DIM_CONTIGUOUS(x)
-
-#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
-
-#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
-
-#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)