# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton


@triton.jit
def reshape_and_cache_kernel_flash(
    key_ptr,  # [num_tokens, num_heads, head_size]
    value_ptr,  # [num_tokens, num_heads, head_size]
    key_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    value_cache_ptr,  # [num_blocks, block_size, num_heads, head_size]
    slot_mapping_ptr,  # [num_tokens]
    k_scale,  # float32
    v_scale,  # float32
    # strides
    key_stride: tl.int64,
    value_stride: tl.int64,
    block_stride: tl.int64,
    page_stride: tl.int64,
    num_heads: tl.constexpr,
    head_size: tl.constexpr,
    block_size: tl.constexpr,
    # FP8 flags
    FP8_KV_CACHE: tl.constexpr,
    # tune parameters
    TILE_SIZE: tl.constexpr,
):
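    # Grid layout: program_id(0) selects the token, program_id(1) selects a
    # TILE_SIZE-wide tile of that token's flattened [num_heads * head_size]
    # key/value vector.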
    token_idx = tl.program_id(axis=0)
    slot_idx = tl.load(slot_mapping_ptr + token_idx).to(tl.int64)
    if slot_idx < 0:
        # Padding token that should be ignored.
        return

    tile_i = tl.program_id(axis=1)
    tile_offs = tl.arange(0, TILE_SIZE)
    tile_pos = tile_i * TILE_SIZE + tile_offs

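    # A slot is a flat position in the paged cache: it decomposes into the
    # cache block holding the token and the token's offset within that block.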
    block_idx = slot_idx // block_size
    block_offset = slot_idx % block_size

    src_key_idx = token_idx * key_stride
    src_value_idx = token_idx * value_stride

    tgt_idx = block_idx * block_stride + block_offset * page_stride

    # [TILE_SIZE]
    key_load = tl.load(
        key_ptr + src_key_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
    )
    if FP8_KV_CACHE:
        # tl.store will do the correct implicit cast to fp8,
        # based on the key_cache_ptr.dtype.element_ty
        key_tile = key_load if key_load.dtype.is_fp8() else key_load / tl.load(k_scale)
    else:
        key_tile = key_load

    # [TILE_SIZE]
    value_load = tl.load(
        value_ptr + src_value_idx + tile_pos, mask=tile_pos < (num_heads * head_size)
    )
    if FP8_KV_CACHE:
        if value_load.dtype.is_fp8():
            value_tile = value_load
        else:
            # tl.store will do the correct implicit cast to fp8,
            # based on the value_cache_ptr.dtype.element_ty
            value_tile = value_load / tl.load(v_scale)
    else:
        value_tile = value_load

    tl.store(
        key_cache_ptr + tgt_idx + tile_pos,
        key_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    tl.store(
        value_cache_ptr + tgt_idx + tile_pos,
        value_tile,
        mask=tile_pos < (num_heads * head_size),
    )
    return


def triton_reshape_and_cache_flash(
|
|
key: torch.Tensor, # [num_tokens, num_heads, head_size]
|
|
value: torch.Tensor, # [num_tokens, num_heads, head_size]
|
|
# [num_blocks, block_size, num_heads, head_size]
|
|
key_cache: torch.Tensor,
|
|
# [num_blocks, block_size, num_heads, head_size]
|
|
value_cache: torch.Tensor,
|
|
slot_mapping: torch.Tensor, # [num_tokens]
|
|
kv_cache_dtype: str, # "auto", "fp8"
|
|
k_scale: torch.Tensor, # float32
|
|
v_scale: torch.Tensor, # float32
|
|
):
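    """Scatter per-token key/value vectors into the paged KV cache (flash
    layout, [num_blocks, block_size, num_heads, head_size]) at the slots
    given by slot_mapping, scaling and casting to fp8 when requested."""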
    num_heads = key.shape[1]
    head_size = key.shape[2]
    block_size = key_cache.shape[1]
    n = num_heads * head_size

    key_stride = key.stride()[0]
    value_stride = value.stride()[0]
    block_stride = key_cache.stride()[0]
    page_stride = key_cache.stride()[1]

    head_stride = key_cache.stride()[2]
    assert head_stride == head_size, "only contiguous heads are supported"

    assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), (
        f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
    )
    kv_cache_torch_dtype = (
        current_platform.fp8_dtype()
        if kv_cache_dtype.startswith("fp8")
        else key_cache.dtype
    )

    if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"):
        # to avoid an erroneous implicit cast in the triton kernel
        # (tl.store to uint8)
        # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
        key_cache = key_cache.view(kv_cache_torch_dtype)
        value_cache = value_cache.view(kv_cache_torch_dtype)
        assert kv_cache_torch_dtype != torch.uint8, (
            "explicit fp8 cast and store to "
            "uint8 is not supported by triton reshape_and_cache_flash"
        )

    FP8_KV_CACHE = kv_cache_dtype.startswith("fp8")
    assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
        torch.float8_e4m3fn,
        torch.float8_e5m2,
        torch.uint8,
        torch.float8_e4m3fnuz,
    ], (
        "unsupported dtype of KV cache tensor, got "
        f"{kv_cache_torch_dtype}. Supported kv cache dtypes: fp8e4m3fn, "
        "fp8e5m2, uint8, bfloat16, float16, float32, fp8e4m3fnuz."
    )

    # heuristics instead of autotuning
    TILE_SIZE = min(2048, triton.next_power_of_2(n))
    if current_platform.is_rocm() or current_platform.is_xpu():
        num_stages = 4
        num_warps = 8
    else:  # cuda
        num_stages = 10
        num_warps = 16
        if torch.cuda.get_device_capability(key.device)[0] < 9:
            TILE_SIZE = min(512, TILE_SIZE)

    # TODO(ngl): maybe replace with static launch grid to avoid overhead if
    # using cudagraphs
    grid = lambda meta: (
        slot_mapping.shape[0],
        triton.cdiv(n, meta["TILE_SIZE"]),
    )

    reshape_and_cache_kernel_flash[grid](
        key_ptr=key,
        value_ptr=value,
        key_cache_ptr=key_cache,
        value_cache_ptr=value_cache,
        slot_mapping_ptr=slot_mapping,
        k_scale=k_scale,
        v_scale=v_scale,
        # strides
        key_stride=key_stride,
        value_stride=value_stride,
        block_stride=block_stride,
        page_stride=page_stride,
        num_heads=num_heads,
        head_size=head_size,
        block_size=block_size,
        # FP8 flags
        FP8_KV_CACHE=FP8_KV_CACHE,
        # autotune parameters
        TILE_SIZE=TILE_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )
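

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): illustrates the
    # expected tensor shapes, assuming a CUDA-capable GPU is available; the
    # sizes below are arbitrary examples.
    num_tokens, num_heads, head_size = 8, 4, 64
    num_blocks, block_size = 16, 16
    device = "cuda"

    key = torch.randn(
        num_tokens, num_heads, head_size, dtype=torch.float16, device=device
    )
    value = torch.randn_like(key)
    key_cache = torch.zeros(
        num_blocks,
        block_size,
        num_heads,
        head_size,
        dtype=torch.float16,
        device=device,
    )
    value_cache = torch.zeros_like(key_cache)
    # Each token gets one flat slot: slot = block_idx * block_size + block_offset.
    slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device)
    k_scale = torch.ones(1, dtype=torch.float32, device=device)
    v_scale = torch.ones(1, dtype=torch.float32, device=device)

    triton_reshape_and_cache_flash(
        key, value, key_cache, value_cache, slot_mapping, "auto", k_scale, v_scale
    )
    # Token 0 was assigned slot 0, i.e. block 0, offset 0 of the cache.
    assert torch.allclose(key_cache[0, 0], key[0])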