diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 7835b1ec0..ef88cf40e 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -223,7 +223,7 @@ jobs: fail-fast: false matrix: runner: [linux-mi300-gpu-1, linux-mi325-gpu-1] - part: [0, 1, 2, 3, 4, 5, 6] + part: [0, 1, 2, 3, 4, 5, 6, 7] runs-on: ${{matrix.runner}} steps: - name: Checkout code @@ -240,7 +240,7 @@ jobs: - name: Run test timeout-minutes: 50 run: | - bash scripts/ci/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 7 + bash scripts/ci/amd_ci_exec.sh python3 run_suite.py --suite per-commit-amd --auto-partition-id ${{ matrix.part }} --auto-partition-size 8 unit-test-backend-2-gpu-amd: if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') && @@ -336,13 +336,14 @@ jobs: bash scripts/ci/amd_ci_install_dependency.sh - name: Run test - timeout-minutes: 10 + timeout-minutes: 14 run: | docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_align.py docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_activation.py + docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_kvcacheio.py pr-test-amd-finish: if: always() diff --git a/sgl-kernel/csrc/common_extension_rocm.cc b/sgl-kernel/csrc/common_extension_rocm.cc index e4eb9c68e..1f94d2615 100644 --- a/sgl-kernel/csrc/common_extension_rocm.cc +++ b/sgl-kernel/csrc/common_extension_rocm.cc @@ -121,6 +121,48 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { */ 
m.def("apply_token_bitmask_inplace_cuda(Tensor logits, Tensor bitmask, Tensor? indices=None) -> ()"); m.impl("apply_token_bitmask_inplace_cuda", &ApplyTokenBitmaskInplace); + + /* + * From csrc/kvcacheio + */ + m.def( + "transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " + "dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer); + m.def( + "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " + "dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf); + m.def( + "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, " + "Tensor src_indices, Tensor dst_indices, int item_size, int num_layers, int block_quota, int " + "num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer", torch::kCUDA, &transfer_kv_all_layer); + m.def( + "transfer_kv_all_layer_lf_pf(Tensor src_k_layers, Tensor dst_k, Tensor src_v_layers, Tensor dst_v, " + "Tensor src_indices, Tensor dst_indices, int item_size, int dst_layout_dim, int num_layers, int block_quota, int " + "num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_lf_pf", torch::kCUDA, &transfer_kv_all_layer_lf_pf); + m.def( + "transfer_kv_per_layer_mla(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, int " + "block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla); + m.def( + "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, " + "int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_per_layer_mla_pf_lf", 
torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf); + m.def( + "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int " + "item_size, int num_layers, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_mla", torch::kCUDA, &transfer_kv_all_layer_mla); + m.def( + "transfer_kv_all_layer_mla_lf_pf(Tensor src_layers, Tensor dst, Tensor src_indices, Tensor dst_indices, " + "int item_size, int dst_layout_dim, int num_layers, int block_quota, int num_warps_per_block) -> ()"); + m.impl("transfer_kv_all_layer_mla_lf_pf", torch::kCUDA, &transfer_kv_all_layer_mla_lf_pf); + m.def( + "transfer_kv_direct(Tensor[] src_layers, Tensor[] dst_layers, Tensor src_indices, Tensor dst_indices, int " + "page_size) -> ()"); + m.impl("transfer_kv_direct", torch::kCUDA, &transfer_kv_direct); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/csrc/kvcacheio/transfer.cu b/sgl-kernel/csrc/kvcacheio/transfer.cu index cbf5feeea..fab0d3bb8 100644 --- a/sgl-kernel/csrc/kvcacheio/transfer.cu +++ b/sgl-kernel/csrc/kvcacheio/transfer.cu @@ -4,21 +4,31 @@ #include +#ifndef USE_ROCM +#define WARP_SIZE 32 #include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#include "utils.h" // WARP_SIZE +#endif __device__ __forceinline__ void transfer_item_warp(int32_t lane_id, const void* src_addr, void* dst_addr, int64_t item_size_bytes) { - // todo, different chunk size - int total_chunks = item_size_bytes / 8; - const int64_t* src_8 = reinterpret_cast<const int64_t*>(src_addr); - int64_t* dst_8 = reinterpret_cast<int64_t*>(dst_addr); + const uint64_t* __restrict__ src = static_cast<const uint64_t*>(src_addr); + uint64_t* __restrict__ dst = static_cast<uint64_t*>(dst_addr); + const int total_chunks = item_size_bytes / sizeof(uint64_t); + #pragma unroll - for (int j = lane_id; j < total_chunks; j += 32) { - const int64_t* src_addr_lane = &src_8[j]; - int64_t* dst_addr_lane = &dst_8[j]; - int64_t temp_val; - asm volatile("ld.global.nc.b64 %0, [%1];" : 
"=l"(temp_val) : "l"(src_addr_lane) : "memory"); - asm volatile("st.global.cg.b64 [%0], %1;" ::"l"(dst_addr_lane), "l"(temp_val) : "memory"); + for (int j = lane_id; j < total_chunks; j += WARP_SIZE) { +#ifndef USE_ROCM + uint64_t tmp; + asm volatile("ld.global.nc.b64 %0,[%1];" : "=l"(tmp) : "l"(src + j) : "memory"); + asm volatile("st.global.cg.b64 [%0],%1;" ::"l"(dst + j), "l"(tmp) : "memory"); + +#else + uint64_t tmp = __builtin_nontemporal_load(src + j); + __builtin_nontemporal_store(tmp, dst + j); +#endif } } @@ -78,8 +88,8 @@ __global__ void transfer_kernel_impl( const uintptr_t* __restrict__ src_v_layer_tbl, const uintptr_t* __restrict__ dst_v_layer_tbl) { int32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - int32_t lane_id = tid % 32; - int32_t warp_id = tid / 32; + int32_t lane_id = tid % WARP_SIZE; + int32_t warp_id = tid / WARP_SIZE; for (int i = 0; i < items_per_warp; ++i) { int64_t item_id = warp_id * items_per_warp + i; @@ -139,7 +149,7 @@ void transfer_kv_launcher( const int64_t items_per_warp = div_up(num_items, block_quota * num_warps_per_block); const int32_t num_blocks = div_up(num_items, items_per_warp * num_warps_per_block); dim3 grid_dim(num_blocks, 1, 1); - const int32_t threads_per_block = num_warps_per_block * 32; + const int32_t threads_per_block = num_warps_per_block * WARP_SIZE; const void* src_k_ptr = src_k.defined() ? src_k.data_ptr() : nullptr; void* dst_k_ptr = dst_k.defined() ? 
dst_k.data_ptr() : nullptr; diff --git a/sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h b/sgl-kernel/include/pytorch_extension_utils_rocm.h similarity index 100% rename from sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h rename to sgl-kernel/include/pytorch_extension_utils_rocm.h diff --git a/sgl-kernel/python/sgl_kernel/kvcacheio.py b/sgl-kernel/python/sgl_kernel/kvcacheio.py index fd05e8466..913cbc5e3 100644 --- a/sgl-kernel/python/sgl_kernel/kvcacheio.py +++ b/sgl-kernel/python/sgl_kernel/kvcacheio.py @@ -3,6 +3,13 @@ from typing import List import torch +def is_hip() -> bool: + return torch.version.hip is not None + + +_is_hip = is_hip() + + def transfer_kv_per_layer( src_k: torch.Tensor, dst_k: torch.Tensor, @@ -12,7 +19,7 @@ def transfer_kv_per_layer( dst_indices: torch.Tensor, item_size: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_per_layer( src_k, @@ -38,7 +45,7 @@ def transfer_kv_per_layer_pf_lf( item_size: int, src_layout_dim: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf( src_k, @@ -65,7 +72,7 @@ def transfer_kv_all_layer( item_size: int, num_layers: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_all_layer( src_k_layers, @@ -92,7 +99,7 @@ def transfer_kv_all_layer_lf_pf( dst_layout_dim: int, num_layers: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf( src_k_layers, @@ -128,7 +135,7 @@ def transfer_kv_per_layer_mla( dst_indices: torch.Tensor, item_size: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): 
torch.ops.sgl_kernel.transfer_kv_per_layer_mla( src, @@ -150,7 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf( item_size: int, src_layout_dim: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf( src, @@ -173,7 +180,7 @@ def transfer_kv_all_layer_mla( item_size: int, num_layers: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_all_layer_mla( src_layers, @@ -196,7 +203,7 @@ def transfer_kv_all_layer_mla_lf_pf( dst_layout_dim: int, num_layers: int, block_quota: int = 2, - num_warps_per_block: int = 32, + num_warps_per_block: int = 16 if _is_hip else 32, ): torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf( src_layers, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 2105c7c1f..6e3466ec3 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -49,6 +49,7 @@ sources = [ "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/speculative/eagle_utils.cu", + "csrc/kvcacheio/transfer.cu", ] cxx_flags = ["-O3"] diff --git a/test/srt/hicache/test_hicache.py b/test/srt/hicache/test_hicache.py index 3fee235ad..f7616d098 100644 --- a/test/srt/hicache/test_hicache.py +++ b/test/srt/hicache/test_hicache.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import is_hip, kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -11,6 +11,8 @@ from sglang.test.test_utils import ( popen_launch_server, ) +_is_hip = is_hip() + class TestHiCache(CustomTestCase): @classmethod @@ -26,7 +28,7 @@ class TestHiCache(CustomTestCase): "--mem-fraction-static", 0.7, "--hicache-size", - 100, + 100 if not _is_hip else 200, ], ) diff --git 
a/test/srt/hicache/test_hicache_mla.py b/test/srt/hicache/test_hicache_mla.py index 5d306453c..c5db0f74a 100644 --- a/test/srt/hicache/test_hicache_mla.py +++ b/test/srt/hicache/test_hicache_mla.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import is_hip, kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MLA_MODEL_NAME_FOR_TEST, @@ -11,6 +11,12 @@ from sglang.test.test_utils import ( popen_launch_server, ) +_is_hip = is_hip() +if _is_hip: + hicache_args = ["--hicache-size", 200] +else: + hicache_args = ["--hicache-ratio", 2] + class TestHierarchicalMLA(CustomTestCase): @classmethod @@ -24,9 +30,8 @@ class TestHierarchicalMLA(CustomTestCase): other_args=[ "--trust-remote-code", "--enable-hierarchical-cache", - "--hicache-ratio", - 2, - ], + ] + + hicache_args, ) @classmethod diff --git a/test/srt/hicache/test_hicache_storage.py b/test/srt/hicache/test_hicache_storage.py index aadc9529d..7bc947b8c 100644 --- a/test/srt/hicache/test_hicache_storage.py +++ b/test/srt/hicache/test_hicache_storage.py @@ -1,7 +1,7 @@ import unittest from types import SimpleNamespace -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import is_hip, kill_process_tree from sglang.test.run_eval import run_eval from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, @@ -11,6 +11,8 @@ from sglang.test.test_utils import ( popen_launch_server, ) +_is_hip = is_hip() + class TestHiCache(CustomTestCase): @classmethod @@ -26,7 +28,7 @@ class TestHiCache(CustomTestCase): "--mem-fraction-static", 0.7, "--hicache-size", - 100, + 100 if not _is_hip else 200, "--page-size", "64", "--hicache-storage-backend", diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 003942e65..2b1ef4c53 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -162,6 +162,9 @@ suites = { # Add AMD tests suite_amd = { 
"per-commit-amd": [ + TestFile("hicache/test_hicache.py", 116), + TestFile("hicache/test_hicache_mla.py", 127), + TestFile("hicache/test_hicache_storage.py", 127), TestFile("lora/test_lora.py", 200), TestFile("lora/test_lora_eviction.py", 200), TestFile("lora/test_lora_backend.py", 99),