[AMD] Support Hierarchical Caching on AMD GPUs (#8236)

This commit is contained in:
Hubert Lu
2025-08-28 15:27:07 -07:00
committed by GitHub
parent 5343058875
commit 711390a971
10 changed files with 105 additions and 32 deletions

View File

@@ -4,21 +4,31 @@
#include <cstdint>
#ifndef USE_ROCM
#define WARP_SIZE 32
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
#include "utils.h" // WARP_SIZE
#endif
// Copies one item from `src_addr` to `dst_addr` in 64-bit chunks, with the
// work split across the lanes of a warp: the calling lane handles chunks
// lane_id, lane_id + WARP_SIZE, lane_id + 2 * WARP_SIZE, ...
//
// Parameters:
//   lane_id          - index of the calling thread within its warp; used as
//                      the starting chunk index for this lane.
//   src_addr         - start of the source item.
//   dst_addr         - start of the destination item.
//   item_size_bytes  - size of the item in bytes.
//
// Notes:
//   * Only item_size_bytes / 8 whole chunks are copied; any trailing bytes
//     (item_size_bytes % 8) are NOT transferred.
//   * Both pointers are accessed as uint64_t, so they are presumably required
//     to be 8-byte aligned, and (because of __restrict__) must not overlap
//     -- TODO confirm against callers.
//   * WARP_SIZE is expected to be provided by the surrounding file (a literal
//     32 on CUDA, taken from "utils.h" on ROCm).
//
// TODO: support chunk sizes other than 8 bytes.
__device__ __forceinline__ void
transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) {
  const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr);
  uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr);
  const int total_chunks = item_size_bytes / sizeof(uint64_t);

#pragma unroll
  for (int j = lane_id; j < total_chunks; j += WARP_SIZE) {
#ifndef USE_ROCM
    // NVIDIA path: 64-bit load with the .nc qualifier and 64-bit store with
    // the .cg cache operator, issued as inline PTX.
    uint64_t tmp;
    asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory");
    asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory");
#else
    // ROCm path: inline PTX is unavailable, so use the compiler's
    // non-temporal load/store builtins for the same 64-bit chunk.
    uint64_t tmp = __builtin_nontemporal_load(src + j);
    __builtin_nontemporal_store(tmp, dst + j);
#endif
  }
}
@@ -78,8 +88,8 @@ __global__ void transfer_kernel_impl(
const uintptr_t* __restrict__ src_v_layer_tbl,
const uintptr_t* __restrict__ dst_v_layer_tbl) {
int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
int32_t lane_id = tid % 32;
int32_t warp_id = tid / 32;
int32_t lane_id = tid % WARP_SIZE;
int32_t warp_id = tid / WARP_SIZE;
for (int i = 0; i < items_per_warp; ++i) {
int64_t item_id = warp_id * items_per_warp + i;
@@ -139,7 +149,7 @@ void transfer_kv_launcher(
const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block);
const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block);
dim3 grid_dim(num_blocks, 1, 1);
const int32_t threads_per_block = num_warps_per_block * 32;
const int32_t threads_per_block = num_warps_per_block * WARP_SIZE;
const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr;
void* dst_k_ptr = dst_k.defined() ? dst_k.data_ptr() : nullptr;