[AMD] Support Hierarchical Caching on AMD GPUs (#8236)

Hubert Lu
2025-08-28 15:27:07 -07:00
committed by GitHub
parent 5343058875
commit 711390a971
10 changed files with 105 additions and 32 deletions
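In short: the diff below adds a ROCm/HIP detection helper and uses it to halve the default num_warps_per_block (32 -> 16) in every transfer_kv_* wrapper when running on AMD GPUs. A minimal, self-contained sketch of the pattern follows; the demo function is hypothetical, and the wavefront-width rationale is an assumption, not stated in the commit:

import torch


def is_hip() -> bool:
    # True when PyTorch is built against ROCm/HIP, i.e. targeting AMD GPUs.
    return torch.version.hip is not None


_is_hip = is_hip()


def transfer_kv_demo(num_warps_per_block: int = 16 if _is_hip else 32) -> None:
    # Hypothetical stand-in for the sgl_kernel transfer_kv_* wrappers, which
    # forward num_warps_per_block to torch.ops.sgl_kernel.*.
    # Assumption: AMD wavefronts are 64 threads wide (vs. 32-thread NVIDIA
    # warps), so 16 warps per block on HIP keeps roughly the same
    # threads-per-block budget as 32 warps on CUDA.
    print(f"launching with {num_warps_per_block} warps per block")


transfer_kv_demo()  # 16 on ROCm builds of PyTorch, 32 otherwise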


@@ -3,6 +3,13 @@ from typing import List
import torch
+def is_hip() -> bool:
+    return torch.version.hip is not None
+_is_hip = is_hip()
def transfer_kv_per_layer(
src_k: torch.Tensor,
dst_k: torch.Tensor,
@@ -12,7 +19,7 @@ def transfer_kv_per_layer(
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer(
src_k,
@@ -38,7 +45,7 @@ def transfer_kv_per_layer_pf_lf(
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
src_k,
@@ -65,7 +72,7 @@ def transfer_kv_all_layer(
item_size: int,
num_layers: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer(
src_k_layers,
@@ -92,7 +99,7 @@ def transfer_kv_all_layer_lf_pf(
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
src_k_layers,
@@ -128,7 +135,7 @@ def transfer_kv_per_layer_mla(
dst_indices: torch.Tensor,
item_size: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
src,
@@ -150,7 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf(
item_size: int,
src_layout_dim: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
src,
@@ -173,7 +180,7 @@ def transfer_kv_all_layer_mla(
item_size: int,
num_layers: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
src_layers,
@@ -196,7 +203,7 @@ def transfer_kv_all_layer_mla_lf_pf(
dst_layout_dim: int,
num_layers: int,
block_quota: int = 2,
-num_warps_per_block: int = 32,
+num_warps_per_block: int = 16 if _is_hip else 32,
):
torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
src_layers,