[AMD] Support Hierarchical Caching on AMD GPUs (#8236)
@@ -3,6 +3,13 @@ from typing import List
 import torch
 
 
+def is_hip() -> bool:
+    return torch.version.hip is not None
+
+
+_is_hip = is_hip()
+
+
 def transfer_kv_per_layer(
     src_k: torch.Tensor,
     dst_k: torch.Tensor,
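The helper above relies on ROCm builds of PyTorch exposing a HIP version string via torch.version.hip, while CUDA builds expose None; evaluating it once into the module-level _is_hip flag keeps the per-call cost at zero. A small usage sketch of the resulting default follows; the wrapper name below is illustrative, not part of the diff:

```python
import torch

def is_hip() -> bool:
    # True on ROCm builds of PyTorch (torch.version.hip is a version string),
    # False on CUDA builds (torch.version.hip is None).
    return torch.version.hip is not None

_is_hip = is_hip()

def launch_transfer(num_warps_per_block: int = 16 if _is_hip else 32) -> int:
    # Hypothetical wrapper mirroring the diff's default expression; callers
    # that have tuned a different value can still override it explicitly.
    return num_warps_per_block

print(launch_transfer())   # 16 on ROCm, 32 on CUDA
print(launch_transfer(8))  # an explicit argument always wins over the default
```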
@@ -12,7 +19,7 @@ def transfer_kv_per_layer(
     dst_indices: torch.Tensor,
     item_size: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_per_layer(
         src_k,
@@ -38,7 +45,7 @@ def transfer_kv_per_layer_pf_lf(
     item_size: int,
     src_layout_dim: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_per_layer_pf_lf(
         src_k,
@@ -65,7 +72,7 @@ def transfer_kv_all_layer(
     item_size: int,
     num_layers: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_all_layer(
         src_k_layers,
@@ -92,7 +99,7 @@ def transfer_kv_all_layer_lf_pf(
     dst_layout_dim: int,
     num_layers: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_all_layer_lf_pf(
         src_k_layers,
@@ -128,7 +135,7 @@ def transfer_kv_per_layer_mla(
     dst_indices: torch.Tensor,
     item_size: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_per_layer_mla(
         src,
@@ -150,7 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf(
     item_size: int,
     src_layout_dim: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_per_layer_mla_pf_lf(
         src,
@@ -173,7 +180,7 @@ def transfer_kv_all_layer_mla(
     item_size: int,
     num_layers: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_all_layer_mla(
         src_layers,
@@ -196,7 +203,7 @@ def transfer_kv_all_layer_mla_lf_pf(
     dst_layout_dim: int,
     num_layers: int,
     block_quota: int = 2,
-    num_warps_per_block: int = 32,
+    num_warps_per_block: int = 16 if _is_hip else 32,
 ):
     torch.ops.sgl_kernel.transfer_kv_all_layer_mla_lf_pf(
         src_layers,
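A plausible reading of the new default, offered as an assumption rather than something the commit states: AMD data-center GPUs execute 64-lane wavefronts where NVIDIA GPUs execute 32-lane warps, so halving num_warps_per_block on HIP keeps the number of threads launched per block unchanged. A minimal sketch of that arithmetic:

```python
# Sketch only: the 64- and 32-lane widths are hardware characteristics of AMD CDNA
# wavefronts and NVIDIA warps; tying them to this default is an assumption about
# the commit's intent, not something stated in the diff.
HIP_WAVEFRONT_SIZE = 64
CUDA_WARP_SIZE = 32

def threads_per_block(num_warps_per_block: int, lanes_per_warp: int) -> int:
    return num_warps_per_block * lanes_per_warp

# 16 wavefronts * 64 lanes == 32 warps * 32 lanes == 1024 threads per block
assert threads_per_block(16, HIP_WAVEFRONT_SIZE) == threads_per_block(32, CUDA_WARP_SIZE) == 1024
```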