Files
sglang/python/sglang/srt/mem_cache/memory_pool.py

1716 lines
59 KiB
Python

"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from __future__ import annotations
from dataclasses import dataclass
from sglang.srt.configs.mamba_utils import Mamba2CacheParams
from sglang.srt.layers.attention.nsa import index_buf_accessor
from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
"""
Memory pool.
SGLang has two levels of memory pool.
ReqToTokenPool maps a request to its token locations.
TokenToKVPoolAllocator manages the indices to kv cache data.
KVCache actually holds the physical kv cache.
"""
import abc
import logging
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import triton
import triton.language as tl
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
if TYPE_CHECKING:
    # Only needed for the annotation of KVCache.register_layer_transfer_counter.
    from sglang.srt.managers.cache_controller import LayerDoneCounter

logger = logging.getLogger(__name__)

# Bytes in one GiB; allocation sizes are reported in this unit.
GB = 1 << 30

# Platform flags, evaluated once at import time.
_is_cuda = is_cuda()
_is_npu = is_npu()

if _is_npu:
    # Provides torch_npu._npu_reshape_and_cache used by AscendTokenToKVPool.
    import torch_npu
def get_tensor_size_bytes(t: torch.Tensor) -> int:
    """Return the number of bytes occupied by the elements of ``t``.

    Computed as ``numel() * element_size()`` so the result is always an exact
    Python ``int``. The previous ``np.prod(t.shape)`` form produced a NumPy
    scalar, and a float (``1.0``) for 0-dim tensors whose shape is empty.
    """
    return t.numel() * t.element_size()
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations.

    ``req_to_token`` is a ``(size, max_context_len)`` int32 table; each row is
    assigned to one request and handed out through a free-list of row indices.
    """

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
    ):
        """Allocate the row table on ``device``.

        Args:
            size: Number of request rows.
            max_context_len: Number of token-location columns per row.
            device: Device on which ``req_to_token`` is allocated.
            enable_memory_saver: Whether the allocation happens inside a
                memory-saver region tagged as KV cache.
        """
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )

        self.free_slots = list(range(size))

    def write(self, indices, values):
        """Assign ``values`` into ``req_to_token`` at ``indices`` (any torch index)."""
        self.req_to_token[indices] = values

    def available_size(self):
        """Return the number of free request rows."""
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Take ``need_size`` row indices from the head of the free list.

        Returns:
            The allocated row indices, or ``None`` when fewer than
            ``need_size`` rows are free (nothing is allocated in that case).
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Append one row index, or a list of row indices, to the free list."""
        if isinstance(free_index, (int,)):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Mark every row as free again; table contents are left untouched."""
        self.free_slots = list(range(self.size))
class MambaPool:
@dataclass(frozen=True, kw_only=True)
class State:
conv: torch.Tensor
temporal: torch.Tensor
def at_layer_idx(self, layer: int):
return type(self)(**{k: v[layer] for k, v in vars(self).items()})
def mem_usage_bytes(self):
return sum(get_tensor_size_bytes(t) for t in vars(self).values())
@dataclass(frozen=True, kw_only=True)
class SpeculativeState(State):
intermediate_ssm: torch.Tensor
intermediate_conv_window: torch.Tensor
def __init__(
self,
*,
size: int,
cache_params: "Mamba2CacheParams",
device: str,
speculative_num_draft_tokens: Optional[int] = None,
):
conv_state_shape = cache_params.shape.conv
temporal_state_shape = cache_params.shape.temporal
conv_dtype = cache_params.dtype.conv
ssm_dtype = cache_params.dtype.temporal
num_mamba_layers = len(cache_params.layers)
# assume conv_state = (dim, state_len)
assert conv_state_shape[0] > conv_state_shape[1]
conv_state = torch.zeros(
size=(num_mamba_layers, size + 1) + conv_state_shape,
dtype=conv_dtype,
device=device,
)
temporal_state = torch.zeros(
size=(num_mamba_layers, size + 1) + temporal_state_shape,
dtype=ssm_dtype,
device=device,
)
if speculative_num_draft_tokens is not None:
# Cache intermediate SSM states per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
intermediate_ssm_state_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
temporal_state_shape[0],
temporal_state_shape[1],
temporal_state_shape[2],
),
dtype=ssm_dtype,
device="cuda",
)
# Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
# Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
intermediate_conv_window_cache = torch.zeros(
size=(
num_mamba_layers,
size + 1,
speculative_num_draft_tokens,
conv_state_shape[0],
conv_state_shape[1],
),
dtype=conv_dtype,
device="cuda",
)
self.mamba_cache = self.SpeculativeState(
conv=conv_state,
temporal=temporal_state,
intermediate_ssm=intermediate_ssm_state_cache,
intermediate_conv_window=intermediate_conv_window_cache,
)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB "
f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB "
)
else:
self.mamba_cache = self.State(conv=conv_state, temporal=temporal_state)
logger.info(
f"Mamba Cache is allocated. "
f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
)
self.size = size
self.free_slots = list(range(size))
self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
def get_speculative_mamba2_params_all_layers(self) -> SpeculativeState:
assert isinstance(self.mamba_cache, self.SpeculativeState)
return self.mamba_cache
def mamba2_layer_cache(self, layer_id: int):
return self.mamba_cache.at_layer_idx(layer_id)
def available_size(self):
return len(self.free_slots)
def alloc(self, need_size: int) -> Optional[List[int]]:
if need_size > len(self.free_slots):
return None
select_index = self.free_slots[:need_size]
self.free_slots = self.free_slots[need_size:]
return select_index
def free(self, free_index: Union[int, List[int]]):
if isinstance(free_index, (int,)):
self.free_slots.append(free_index)
else:
self.free_slots.extend(free_index)
self.mamba_cache.conv[:, free_index] = self.mamba_cache.temporal[
:, free_index
] = 0
def clear(self):
self.free_slots = list(range(self.size))
class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations.

    In addition to request rows, every request gets a Mamba state slot from
    ``self.mamba_pool``. Slots are tracked per request id (``rid``), so a
    request that already owns a slot reuses it on later allocations.
    """

    def __init__(
        self,
        *,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        cache_params: "Mamba2CacheParams",
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        """Create the row table and a Mamba pool with the same number of slots.

        Args:
            size: Number of request rows and Mamba slots.
            max_context_len: Token-location columns per request row.
            device: Device for all tensors.
            enable_memory_saver: Forwarded to :class:`ReqToTokenPool`.
            cache_params: Mamba cache parameters; ``cache_params.layers`` lists
                the global ids of the Mamba layers.
            speculative_num_draft_tokens: Forwarded to :class:`MambaPool`.
        """
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self.mamba_pool = MambaPool(
            size=size,
            cache_params=cache_params,
            device=device,
            speculative_num_draft_tokens=speculative_num_draft_tokens,
        )
        # Global layer id -> index into the Mamba pool's layer dimension.
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}

        self.device = device
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
            size, dtype=torch.int32, device=self.device
        )

        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}

    # For chunk prefill req, we do not need to allocate mamba cache,
    # We could use allocated mamba cache instead.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        """Allocate ``need_size`` request rows and one Mamba slot per request.

        Args:
            need_size: Number of request rows to allocate.
            reqs: Requests the rows are for (each exposes ``rid``); expected to
                contain ``need_size`` entries.

        Returns:
            The allocated row indices, or ``None`` if not enough rows are free.

        Raises:
            TypeError: If ``reqs`` is ``None`` (rows are released first).
            AssertionError: If a Mamba slot could not be provided for every
                row. Rows and Mamba slots taken by this call are released
                before raising, so nothing leaks.
        """
        select_index = super().alloc(need_size)
        if select_index is None:
            return None
        if reqs is None:
            super().free(select_index)
            raise TypeError("HybridReqToTokenPool.alloc requires `reqs`.")

        mamba_index: List[int] = []
        newly_allocated: List[int] = []
        for req in reqs:
            rid = req.rid
            mid = self.rid_to_mamba_index_mapping.get(rid)
            if mid is None:
                allocated = self.mamba_pool.alloc(1)
                if allocated is None:
                    # Mamba pool exhausted; reported by the length check below
                    # instead of appending ``None``.
                    continue
                mid = allocated[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
                newly_allocated.append(mid)
            mamba_index.append(mid)

        if len(select_index) != len(mamba_index):
            self._rollback_alloc(select_index, newly_allocated)
            raise AssertionError(
                "Not enough space for mamba cache, try to increase --max-mamba-cache-size."
            )

        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def _rollback_alloc(self, select_index: List[int], new_mamba_index: List[int]):
        """Release rows and fresh Mamba slots taken by a failed :meth:`alloc`."""
        super().free(select_index)
        if new_mamba_index:
            for mid in new_mamba_index:
                rid = self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)
            self.mamba_pool.free(new_mamba_index)

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        """Return the Mamba slot index of each request row in ``req_indices``."""
        return self.req_index_to_mamba_index_mapping[req_indices]

    def mamba2_layer_cache(self, layer_id: int):
        """Return the Mamba state of global layer ``layer_id``."""
        assert layer_id in self.mamba_map
        return self.mamba_pool.mamba2_layer_cache(self.mamba_map[layer_id])

    def get_speculative_mamba2_params_all_layers(self) -> MambaPool.SpeculativeState:
        """Return the all-layer speculative Mamba state."""
        return self.mamba_pool.get_speculative_mamba2_params_all_layers()

    # For chunk prefill, we can not free mamba cache, we need use it in the future
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        """Free request row(s) and, unless told otherwise, their Mamba slots.

        Args:
            free_index: Row index or list of row indices.
            free_mamba_cache: When ``False`` the Mamba slots and rid mappings
                are kept so a later allocation for the same request reuses them.
        """
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        """Free every request row and every Mamba slot."""
        super().clear()
        self.mamba_pool.clear()
class KVCache(abc.ABC):
    """Abstract base for token-indexed KV cache pools.

    Subclasses allocate the physical buffers and implement the per-layer
    accessors; this base stores the shared configuration.
    """

    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Record configuration common to every KV cache.

        fp8 dtypes are stored as ``torch.uint8`` (``store_dtype``). A falsy
        ``start_layer`` becomes 0 and a falsy ``end_layer`` becomes
        ``layer_num - 1``.
        """
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device

        # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
        fp8_dtypes = (torch.float8_e5m2, torch.float8_e4m3fn)
        self.store_dtype = torch.uint8 if dtype in fp8_dtypes else dtype

        self.layer_num = layer_num
        self.start_layer = start_layer or 0
        self.end_layer = end_layer or layer_num - 1
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        self.mem_usage = 0

        # used for chunked cpu-offloading
        self.cpu_offloading_chunk_size = 8192

        # default state for optional layer-wise transfer control
        self.layer_transfer_counter = None

    def _finalize_allocation_log(self, num_tokens: int):
        """Log the allocated KV size and set ``mem_usage`` (in GB).

        ``get_kv_size_bytes`` may return either a ``(k_bytes, v_bytes)`` tuple
        or a single combined byte count.
        """
        size_bytes = self.get_kv_size_bytes()
        if not isinstance(size_bytes, tuple):
            kv_size_GB = size_bytes / GB
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
            )
            self.mem_usage = kv_size_GB
            return

        k_size_GB, v_size_GB = (part / GB for part in size_bytes)
        logger.info(
            f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
        )
        self.mem_usage = k_size_GB + v_size_GB

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the key buffer of ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the value buffer of ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(key_buffer, value_buffer)`` of ``layer_id``."""
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        """Write ``cache_k``/``cache_v`` for ``layer`` at token slots ``loc``."""
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
        """Attach a counter that accessors wait on for layer-wise loading."""
        self.layer_transfer_counter = layer_transfer_counter

    def get_cpu_copy(self, indices):
        """Copy the given slots to CPU; optional, not implemented by default."""
        raise NotImplementedError()

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Restore slots from a CPU copy; optional, not implemented by default."""
        raise NotImplementedError()
class MHATokenToKVPool(KVCache):
    """Multi-head-attention KV cache with one K and one V buffer per layer.

    Each per-layer buffer has shape ``[size + page_size, head_num, head_dim]``
    and dtype ``store_dtype``. Accessors take layer ids offset by
    ``start_layer``; the local list index is ``layer_id - start_layer``.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
        enable_kv_cache_copy: bool = False,
    ):
        """Allocate K/V buffers for ``layer_num`` layers and log their size.

        Args:
            size: Number of token slots (``page_size`` padding slots are added).
            page_size: Tokens per page.
            dtype: Logical KV dtype (fp8 is stored as ``uint8``).
            head_num: Number of KV heads per layer.
            head_dim: Size of each head.
            layer_num: Number of layers held by this pool.
            device: Device for the buffers.
            enable_memory_saver: Allocate inside a memory-saver region.
            start_layer: First layer id served by this pool (default 0).
            end_layer: Last layer id served (default ``layer_num - 1``).
            enable_kv_cache_copy: Prepare and warm up the tiled copy kernel
                used by :meth:`move_kv_cache`.
        """
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )
        self.head_num = head_num
        self.head_dim = head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            # TODO(shangming): abstract custom allocator class for more backends
            from mooncake.allocator import NVLinkAllocator

            allocator = NVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        self._create_buffers()

        self.device_module = torch.get_device_module(self.device)
        # Secondary stream used to overlap the K and V writes during graph capture.
        self.alt_stream = self.device_module.Stream() if _is_cuda else None

        if enable_kv_cache_copy:
            self._init_kv_copy_and_warmup()
        else:
            self._kv_copy_config = None

        self._finalize_allocation_log(size)

    def _init_kv_copy_and_warmup(self):
        """Pick tiling for ``copy_all_layer_kv_cache_tiled`` and launch it once.

        The tile size is chosen from the per-token byte stride of the first
        buffer. A single launch with a dummy location (N=1) follows, presumably
        so the Triton kernel is compiled before the first real copy.
        """
        # Heuristics for KV copy tiling
        _KV_COPY_STRIDE_THRESHOLD_LARGE = 8192
        _KV_COPY_STRIDE_THRESHOLD_MEDIUM = 4096
        _KV_COPY_TILE_SIZE_LARGE = 512
        _KV_COPY_TILE_SIZE_MEDIUM = 256
        _KV_COPY_TILE_SIZE_SMALL = 128
        _KV_COPY_NUM_WARPS_LARGE_TILE = 8
        _KV_COPY_NUM_WARPS_SMALL_TILE = 4

        stride_bytes = int(self.data_strides[0].item())
        if stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_LARGE:
            bytes_per_tile = _KV_COPY_TILE_SIZE_LARGE
        elif stride_bytes >= _KV_COPY_STRIDE_THRESHOLD_MEDIUM:
            bytes_per_tile = _KV_COPY_TILE_SIZE_MEDIUM
        else:
            bytes_per_tile = _KV_COPY_TILE_SIZE_SMALL

        self._kv_copy_config = {
            "bytes_per_tile": bytes_per_tile,
            "byte_tiles": (stride_bytes + bytes_per_tile - 1) // bytes_per_tile,
            "num_warps": (
                _KV_COPY_NUM_WARPS_SMALL_TILE
                if bytes_per_tile <= _KV_COPY_TILE_SIZE_MEDIUM
                else _KV_COPY_NUM_WARPS_LARGE_TILE
            ),
        }

        dummy_loc = torch.zeros(1, dtype=torch.int32, device=self.device)
        grid = (self.data_ptrs.numel(), self._kv_copy_config["byte_tiles"])

        copy_all_layer_kv_cache_tiled[grid](
            self.data_ptrs,
            self.data_strides,
            dummy_loc,
            dummy_loc,
            1,
            1,
            BYTES_PER_TILE=self._kv_copy_config["bytes_per_tile"],
            num_warps=self._kv_copy_config["num_warps"],
            num_stages=2,
        )

    def _create_buffers(self):
        """Allocate per-layer K/V buffers plus pointer/stride tables for copies."""
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                # [size, head_num, head_dim] for each layer
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.k_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]
                self.v_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]

        # Device-side tables for the tiled copy kernel: all K pointers followed
        # by all V pointers, and the byte size of one token row of each buffer.
        self.k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.v_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.v_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        self.data_strides = torch.tensor(
            [
                np.prod(x.shape[1:]) * x.dtype.itemsize
                for x in self.k_buffer + self.v_buffer
            ],
            device=self.device,
        )

    def _clear_buffers(self):
        """Drop references to the K/V buffer lists."""
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        """Return ``(k_bytes, v_bytes)`` summed over all layers."""
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += get_tensor_size_bytes(k_cache)
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += get_tensor_size_bytes(v_cache)
        return k_size_bytes, v_size_bytes

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` for all K then V buffers.

        ``item_lens`` is the byte size of one page (``page_size`` token rows).
        """
        # layer_num x [seq_len, head_num, head_dim]
        # layer_num x [page_num, page_size, head_num, head_dim]
        kv_data_ptrs = [
            self._get_key_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_data_lens = [
            self._get_key_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_item_lens = [
            self._get_key_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        """Return the custom (NVLink) memory pool, or ``None``."""
        return self.custom_mem_pool

    def get_cpu_copy(self, indices):
        """Copy K/V rows at ``indices`` of every layer to CPU, in chunks.

        Returns a nested list ``[layer][chunk] -> [k_cpu, v_cpu]``. The device
        is synchronised before and after, through ``self.device_module`` so
        that non-CUDA subclasses (e.g. Ascend) synchronise their own device.
        """
        self.device_module.synchronize()
        kv_cache_cpu = []
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                kv_cache_cpu[-1].append([k_cpu, v_cpu])
        self.device_module.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        """Write a copy produced by :meth:`get_cpu_copy` back at ``indices``."""
        self.device_module.synchronize()
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                k_cpu, v_cpu = (
                    kv_cache_cpu[layer_id][i // chunk_size][0],
                    kv_cache_cpu[layer_id][i // chunk_size][1],
                )
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
                self.k_buffer[layer_id][chunk_indices] = k_chunk
                self.v_buffer[layer_id][chunk_indices] = v_chunk
        self.device_module.synchronize()

    def _get_key_buffer(self, layer_id: int):
        # for internal use of referencing
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_key_buffer(self, layer_id: int):
        # note: get_key_buffer is hooked with synchronization for layer-wise KV cache loading
        # it is supposed to be used only by attention backend not for information purpose
        # same applies to get_value_buffer and get_kv_buffer
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self._get_key_buffer(layer_id)

    def _get_value_buffer(self, layer_id: int):
        # for internal use of referencing
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self._get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        """Write ``cache_k``/``cache_v`` into the layer's buffers at ``loc``.

        When ``cache_k`` is not already ``dtype``, K and V are divided by
        ``k_scale``/``v_scale`` (if given) and cast. The caller's tensors are
        never modified.

        Args:
            layer: Supplies ``layer_id`` unless ``layer_id_override`` is set.
            loc: Token slots to write.
            cache_k: Keys to store.
            cache_v: Values to store.
            k_scale: Optional divisor applied to keys before casting.
            v_scale: Optional divisor applied to values before casting.
            layer_id_override: Layer id used instead of ``layer.layer_id``.
        """
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if layer_id_override is not None:
            layer_id = layer_id_override
        else:
            layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            # Scale out of place: ``div_`` would overwrite the caller's key and
            # value tensors, which callers may still use after this write.
            if k_scale is not None:
                cache_k = cache_k / k_scale
            if v_scale is not None:
                cache_v = cache_v / v_scale
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        if get_is_capture_mode() and self.alt_stream is not None:
            # Overlap the copy of K and V cache for small batch size
            current_stream = self.device_module.current_stream()
            self.alt_stream.wait_stream(current_stream)
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            with self.device_module.stream(self.alt_stream):
                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
            current_stream.wait_stream(self.alt_stream)
        else:
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            self.v_buffer[layer_id - self.start_layer][loc] = cache_v

    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
        """Copy K/V rows ``src_loc`` -> ``tgt_loc`` for every layer in one launch.

        Requires ``enable_kv_cache_copy=True`` at construction.
        """
        N = tgt_loc.numel()
        if N == 0:
            return

        assert (
            self._kv_copy_config is not None
        ), "KV copy not initialized. Set enable_kv_cache_copy=True in __init__"

        cfg = self._kv_copy_config
        N_upper = next_power_of_2(N)
        grid = (self.data_ptrs.numel(), cfg["byte_tiles"])

        copy_all_layer_kv_cache_tiled[grid](
            self.data_ptrs,
            self.data_strides,
            tgt_loc,
            src_loc,
            N,
            N_upper,
            BYTES_PER_TILE=cfg["bytes_per_tile"],
            num_warps=cfg["num_warps"],
            num_stages=2,
        )
class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers.

    Only the full-attention layers own KV storage here. Their global layer ids
    are remapped to dense indices ``0 .. len(full_attention_layer_ids) - 1`` of
    ``self.full_kv_pool``.

    NOTE(review): ``register_layer_transfer_counter`` is inherited and sets the
    counter on this wrapper only, while the accessors below delegate to
    ``full_kv_pool``, which checks its own counter. Confirm whether the counter
    should be forwarded.
    """

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        page_size: int,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        """Build the backing full-attention pool (Ascend on NPU, MHA otherwise)."""
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = page_size
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose

        pool_cls = AscendTokenToKVPool if _is_npu else MHATokenToKVPool
        self.full_kv_pool = pool_cls(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        self.full_attention_layer_id_mapping = {
            global_id: local_id
            for local_id, global_id in enumerate(full_attention_layer_ids)
        }

        k_bytes, v_bytes = self.get_kv_size_bytes()
        self.mem_usage = (k_bytes + v_bytes) / GB

    def get_kv_size_bytes(self):
        """Return ``(k_bytes, v_bytes)`` of the backing pool."""
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        """Return the backing pool's ``(data_ptrs, data_lens, item_lens)``."""
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        """Map a global layer id to its index in ``full_kv_pool``.

        Raises:
            ValueError: If ``layer_id`` is not a full-attention layer.
        """
        mapping = self.full_attention_layer_id_mapping
        if layer_id in mapping:
            return mapping[layer_id]
        raise ValueError(
            f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
        )

    def get_key_buffer(self, layer_id: int):
        return self.full_kv_pool.get_key_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def get_value_buffer(self, layer_id: int):
        return self.full_kv_pool.get_value_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def get_kv_buffer(self, layer_id: int):
        return self.full_kv_pool.get_kv_buffer(
            self._transfer_full_attention_id(layer_id)
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        """Write K/V for ``layer`` into the backing pool at ``loc``."""
        local_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=local_id,
        )

    def get_v_head_dim(self):
        """Return the last dimension of the value buffer of local layer 0."""
        first_value_buffer = self.full_kv_pool.get_value_buffer(0)
        return first_value_buffer.shape[-1]
class SWAKVPool(KVCache):
    """KV cache with separate pools for full and SWA attention layers.

    ``layers_mapping`` maps each global layer id to ``(local_layer_id, is_swa)``;
    accesses are forwarded to ``swa_kv_pool`` or ``full_kv_pool`` accordingly.

    NOTE(review): an inherited ``register_layer_transfer_counter`` call sets the
    counter on this wrapper only; the sub-pools check their own counters.
    """

    def __init__(
        self,
        size: int,
        size_swa: int,
        dtype: torch.dtype,
        swa_attention_layer_ids: List[int],
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
        **kwargs,
    ):
        """Create the SWA pool (``size_swa`` slots) and the full pool (``size`` slots).

        Both sub-pools come from ``token_to_kv_pool_class`` with
        ``page_size=1`` and the memory saver disabled; the remaining ``kwargs``
        are forwarded to both.
        """
        self.size = size
        self.size_swa = size_swa
        self.dtype = dtype
        self.swa_layer_nums = len(swa_attention_layer_ids)
        self.full_layer_nums = len(full_attention_layer_ids)
        kwargs.update(page_size=1, enable_memory_saver=False)
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose

        # The SWA pool is allocated first, then the full pool.
        self.swa_kv_pool = token_to_kv_pool_class(
            size=size_swa,
            dtype=dtype,
            layer_num=self.swa_layer_nums,
            **kwargs,
        )
        self.full_kv_pool = token_to_kv_pool_class(
            size=size,
            dtype=dtype,
            layer_num=self.full_layer_nums,
            **kwargs,
        )

        # Full-attention ids are inserted first, so an id listed in both
        # sequences ends up routed to the SWA pool.
        self.layers_mapping: Dict[int, Tuple[int, bool]] = {
            global_id: (local_id, False)
            for local_id, global_id in enumerate(full_attention_layer_ids)
        }
        self.layers_mapping.update(
            (global_id, (local_id, True))
            for local_id, global_id in enumerate(swa_attention_layer_ids)
        )

        # When assigned, SWA writes translate ``loc`` through this table.
        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None

        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        """Return ``(k_bytes, v_bytes)`` summed over both sub-pools."""
        full_k, full_v = self.full_kv_pool.get_kv_size_bytes()
        swa_k, swa_v = self.swa_kv_pool.get_kv_size_bytes()
        return full_k + swa_k, full_v + swa_v

    def get_contiguous_buf_infos(self):
        """Return buffer infos of the full pool followed by those of the SWA pool."""
        full_infos = self.full_kv_pool.get_contiguous_buf_infos()
        swa_infos = self.swa_kv_pool.get_contiguous_buf_infos()
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            full_part + swa_part for full_part, swa_part in zip(full_infos, swa_infos)
        )
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def _pool_and_local_id(self, layer_id: int):
        """Return ``(sub_pool, local_layer_id)`` for global ``layer_id``."""
        local_id, is_swa = self.layers_mapping[layer_id]
        return (self.swa_kv_pool if is_swa else self.full_kv_pool), local_id

    def get_key_buffer(self, layer_id: int):
        pool, local_id = self._pool_and_local_id(layer_id)
        return pool.get_key_buffer(local_id)

    def get_value_buffer(self, layer_id: int):
        pool, local_id = self._pool_and_local_id(layer_id)
        return pool.get_value_buffer(local_id)

    def get_kv_buffer(self, layer_id: int):
        pool, local_id = self._pool_and_local_id(layer_id)
        return pool.get_kv_buffer(local_id)

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
        """Map full-pool slot indices to SWA-pool slot indices (as int32)."""
        assert self.full_to_swa_index_mapping is not None
        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        """Write K/V for ``layer`` into its sub-pool, translating SWA slots if mapped."""
        local_id, is_swa = self.layers_mapping[layer.layer_id]
        if not is_swa:
            target_pool = self.full_kv_pool
        else:
            target_pool = self.swa_kv_pool
            if self.full_to_swa_index_mapping is not None:
                loc = self.translate_loc_from_full_to_swa(loc)
        target_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=local_id,
        )
class AscendTokenToKVPool(MHATokenToKVPool):
    """MHA KV cache for Ascend NPUs backed by one contiguous paged tensor."""

    def _create_buffers(self):
        """Allocate K and V for all layers as the two halves of ``kv_buffer``.

        ``kv_buffer`` has shape ``[2, layer_num, size // page_size + 1,
        page_size, head_num, head_dim]``; ``k_buffer`` / ``v_buffer`` are views
        of index 0 and 1.
        """
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size // page_size + 1, page_size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            # Continuous memory improves the efficiency of Ascend's transmission backend,
            # while other backends remain unchanged.
            self.kv_buffer = torch.zeros(
                (
                    2,
                    self.layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    self.head_num,
                    self.head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.k_buffer = self.kv_buffer[0]
            self.v_buffer = self.kv_buffer[1]

    # for disagg
    def get_contiguous_buf_infos(self):
        """Return ``(data_ptrs, data_lens, item_lens)`` for all K then V layers.

        Uses the internal accessors so that this informational query does not
        wait on the layer transfer counter. Each layer buffer is already paged,
        so one item (``buffer[0]``) is exactly one page.
        """
        layer_ids = range(self.start_layer, self.start_layer + self.layer_num)
        buffers = [self._get_key_buffer(i) for i in layer_ids] + [
            self._get_value_buffer(i) for i in layer_ids
        ]
        kv_data_ptrs = [buf.data_ptr() for buf in buffers]
        kv_data_lens = [buf.nbytes for buf in buffers]
        kv_item_lens = [buf[0].nbytes for buf in buffers]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        """Write K/V at slots ``loc`` via ``torch_npu._npu_reshape_and_cache``.

        Layer ids are offset by ``start_layer``, like every other accessor of
        this pool. K and V are scaled and cast out of place when ``cache_k`` is
        not already ``dtype``, so the caller's tensors are never modified.
        """
        if layer_id_override is not None:
            layer_id = layer_id_override
        else:
            layer_id = layer.layer_id
        local_layer_id = layer_id - self.start_layer

        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k = cache_k / k_scale
            if v_scale is not None:
                cache_v = cache_v / v_scale
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[local_layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[local_layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )
@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
    """Copy one BLOCK-wide column tile of ``[k_nope | k_rope]`` into the cache.

    Program ``(pid_loc, pid_blk)`` writes columns ``[pid_blk * BLOCK,
    (pid_blk + 1) * BLOCK)`` of token ``pid_loc``'s concatenated row into row
    ``loc_ptr[pid_loc]`` of ``kv_buffer``. A tile may straddle the nope/rope
    boundary, so ``nope_dim`` does not need to be a multiple of ``BLOCK``.
    """
    pid_loc = tl.program_id(0)
    pid_blk = tl.program_id(1)

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    mask = offs < total_dim

    loc = tl.load(loc_ptr + pid_loc)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs

    nope_row_ptr = cache_k_nope_ptr + pid_loc * nope_stride
    rope_row_ptr = cache_k_rope_ptr + pid_loc * rope_stride

    if base + BLOCK <= nope_dim:
        # Tile lies entirely in the nope part.
        src = tl.load(nope_row_ptr + offs, mask=mask)
    elif base >= nope_dim:
        # Tile lies entirely in the rope part.
        src = tl.load(rope_row_ptr + (offs - nope_dim), mask=mask)
    else:
        # Tile straddles the boundary: gather each column from its own source
        # instead of reading the rope tensor at negative offsets.
        is_nope = offs < nope_dim
        src_nope = tl.load(nope_row_ptr + offs, mask=mask & is_nope, other=0)
        src_rope = tl.load(
            rope_row_ptr + (offs - nope_dim),
            mask=mask & (offs >= nope_dim),
            other=0,
        )
        src = tl.where(is_nope, src_nope, src_rope)

    tl.store(dst_ptr, src, mask=mask)
def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    """Store the concatenation of ``cache_k_nope`` and ``cache_k_rope`` at ``loc``.

    Launches ``set_mla_kv_buffer_kernel`` with one program per
    (location, 128-column tile).
    """
    tile = 128
    nope_dim, rope_dim = cache_k_nope.shape[-1], cache_k_rope.shape[-1]
    launch_grid = (loc.numel(), triton.cdiv(nope_dim + rope_dim, tile))

    set_mla_kv_buffer_kernel[launch_grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        nope_dim,
        rope_dim,
        BLOCK=tile,
    )
class MLATokenToKVPool(KVCache):
def __init__(
    self,
    size: int,
    page_size: int,
    dtype: torch.dtype,
    kv_lora_rank: int,
    qk_rope_head_dim: int,
    layer_num: int,
    device: str,
    enable_memory_saver: bool,
    start_layer: Optional[int] = None,
    end_layer: Optional[int] = None,
    use_nsa: bool = False,
    override_kv_cache_dim: Optional[int] = None,
):
    """Allocate one ``[size + page_size, 1, kv_cache_dim]`` buffer per layer.

    ``kv_cache_dim`` is ``override_kv_cache_dim`` when given. Otherwise it is
    656 for NSA with fp8 (e4m3fn) storage, and ``kv_lora_rank +
    qk_rope_head_dim`` in every other case. Allocation is logged here except
    for NSA, which logs after allocating its indexer cache.
    """
    super().__init__(
        size,
        page_size,
        dtype,
        layer_num,
        device,
        enable_memory_saver,
        start_layer,
        end_layer,
    )

    self.kv_lora_rank = kv_lora_rank
    self.qk_rope_head_dim = qk_rope_head_dim
    self.use_nsa = use_nsa
    self.nsa_kv_cache_store_fp8 = use_nsa and dtype == torch.float8_e4m3fn
    if override_kv_cache_dim is not None:
        # An explicit width from the caller takes precedence (previously the
        # argument was accepted but silently ignored).
        self.kv_cache_dim = override_kv_cache_dim
    elif self.nsa_kv_cache_store_fp8:
        # TODO do not hardcode
        self.kv_cache_dim = 656
    else:
        self.kv_cache_dim = kv_lora_rank + qk_rope_head_dim

    # for disagg with nvlink
    self.enable_custom_mem_pool = get_bool_env_var(
        "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
    )
    if self.enable_custom_mem_pool:
        # TODO(shangming): abstract custom allocator class for more backends
        from mooncake.allocator import NVLinkAllocator

        allocator = NVLinkAllocator.get_allocator(self.device)
        self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
    else:
        self.custom_mem_pool = None

    with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
        # Gate on the same flag as MHATokenToKVPool rather than pool truthiness.
        with (
            torch.cuda.use_mem_pool(self.custom_mem_pool)
            if self.enable_custom_mem_pool
            else nullcontext()
        ):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.kv_buffer = [
                torch.zeros(
                    (size + page_size, 1, self.kv_cache_dim),
                    dtype=self.store_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]

    # Device-side table of per-layer buffer pointers.
    self.data_ptrs = torch.tensor(
        [x.data_ptr() for x in self.kv_buffer],
        dtype=torch.uint64,
        device=self.device,
    )
    if not use_nsa:
        # NSA will allocate indexer KV cache later and then log the total size
        self._finalize_allocation_log(size)
def get_kv_size_bytes(self):
    """Return the combined byte size of every per-layer MLA buffer."""
    assert hasattr(self, "kv_buffer")
    return sum((get_tensor_size_bytes(layer_buf) for layer_buf in self.kv_buffer), 0)
# for disagg
def get_contiguous_buf_infos(self):
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
kv_item_lens = [
self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
]
return kv_data_ptrs, kv_data_lens, kv_item_lens
def maybe_get_custom_mem_pool(self):
return self.custom_mem_pool
def get_key_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer]
def get_value_buffer(self, layer_id: int):
if self.layer_transfer_counter is not None:
self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
if self.store_dtype != self.dtype:
return self.kv_buffer[layer_id - self.start_layer][
..., : self.kv_lora_rank
].view(self.dtype)
return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]
def get_kv_buffer(self, layer_id: int):
return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)
def set_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k: torch.Tensor,
cache_v: torch.Tensor,
):
layer_id = layer.layer_id
assert not (self.use_nsa and self.nsa_kv_cache_store_fp8)
if cache_k.dtype != self.dtype:
cache_k = cache_k.to(self.dtype)
if self.store_dtype != self.dtype:
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
self.store_dtype
)
else:
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
def set_mla_kv_buffer(
self,
layer: RadixAttention,
loc: torch.Tensor,
cache_k_nope: torch.Tensor,
cache_k_rope: torch.Tensor,
):
layer_id = layer.layer_id
if self.use_nsa and self.nsa_kv_cache_store_fp8:
# original cache_k: (num_tokens, num_heads 1, hidden 576); we unsqueeze the page_size=1 dim here
# TODO no need to cat
cache_k = torch.cat([cache_k_nope, cache_k_rope], dim=-1)
cache_k = quantize_k_cache(cache_k.unsqueeze(1)).squeeze(1)
cache_k = cache_k.view(self.store_dtype)
self.kv_buffer[layer_id - self.start_layer][loc] = cache_k
else:
if cache_k_nope.dtype != self.dtype:
cache_k_nope = cache_k_nope.to(self.dtype)
cache_k_rope = cache_k_rope.to(self.dtype)
if self.store_dtype != self.dtype:
cache_k_nope = cache_k_nope.view(self.store_dtype)
cache_k_rope = cache_k_rope.view(self.store_dtype)
set_mla_kv_buffer_triton(
self.kv_buffer[layer_id - self.start_layer],
loc,
cache_k_nope,
cache_k_rope,
)
def get_cpu_copy(self, indices):
torch.cuda.synchronize()
kv_cache_cpu = []
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
kv_cache_cpu.append([])
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
"cpu", non_blocking=True
)
kv_cache_cpu[-1].append(kv_cpu)
torch.cuda.synchronize()
return kv_cache_cpu
def load_cpu_copy(self, kv_cache_cpu, indices):
torch.cuda.synchronize()
chunk_size = self.cpu_offloading_chunk_size
for layer_id in range(self.layer_num):
for i in range(0, len(indices), chunk_size):
chunk_indices = indices[i : i + chunk_size]
kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
assert kv_cpu.shape[0] == len(chunk_indices)
kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
self.kv_buffer[layer_id][chunk_indices] = kv_chunk
torch.cuda.synchronize()
class NSATokenToKVPool(MLATokenToKVPool):
    """MLA pool extended with per-layer NSA indexer key buffers.

    Besides the MLA cache (allocated by the parent with ``use_nsa=True``), each
    layer owns a paged ``uint8`` buffer holding the indexer keys together
    with their quantization scales.
    """

    quant_block_size = 128
    index_k_with_scale_buffer_dtype = torch.uint8

    def __init__(
        self,
        size: int,
        page_size: int,
        kv_lora_rank: int,
        dtype: torch.dtype,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        index_head_dim: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Allocate the MLA cache and the indexer key+scale buffers.

        ``index_head_dim`` must be 128 and ``page_size`` must be 64 (both asserted).
        """
        super().__init__(
            size,
            page_size,
            dtype,
            kv_lora_rank,
            qk_rope_head_dim,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
            use_nsa=True,
        )
        # self.index_k_dtype = torch.float8_e4m3fn
        # self.index_k_scale_dtype = torch.float32
        self.index_head_dim = index_head_dim
        # num head == 1 and head dim == 128 for index_k in NSA
        assert index_head_dim == 128

        assert self.page_size == 64

        # The MLA kv_buffer holds ``size + page_size`` token slots. Round that up
        # to whole pages so every valid slot maps to an allocated indexer page.
        # The previous ``(size + page_size + 1) // page_size`` under-allocated
        # when ``size`` was not a multiple of ``page_size``.
        num_pages = (size + page_size + self.page_size - 1) // self.page_size
        # Bytes per page: fp8 data for page_size tokens, then one fp32 scale
        # (4 bytes) per quant block per token.
        bytes_per_page = self.page_size * (
            index_head_dim + index_head_dim // self.quant_block_size * 4
        )

        # Allocate inside the same KV-cache region as the MLA buffers so the
        # memory saver manages the indexer cache together with them.
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.index_k_with_scale_buffer = [
                torch.zeros(
                    # Layout:
                    # ref: test_attention.py :: kv_cache_cast_to_fp8
                    # shape: (num_pages, page_size 64 * head_dim 128 + page_size 64 * fp32_nbytes 4)
                    # data: for page i,
                    #     * buf[i, :page_size * head_dim] for fp8 data
                    #     * buf[i, page_size * head_dim:].view(float32) for scale
                    (num_pages, bytes_per_page),
                    dtype=self.index_k_with_scale_buffer_dtype,
                    device=device,
                )
                for _ in range(layer_num)
            ]
        self._finalize_allocation_log(size)

    def get_index_k_with_scale_buffer(self, layer_id: int) -> torch.Tensor:
        """Return the raw key+scale buffer of ``layer_id``.

        Waits on ``layer_transfer_counter`` for this layer when one is set.
        """
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self.index_k_with_scale_buffer[layer_id - self.start_layer]

    def get_index_k_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        """Gather ``seq_len`` indexer keys from ``page_indices`` via ``index_buf_accessor.GetK``."""
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetK.execute(
            self, buf, seq_len=seq_len, page_indices=page_indices
        )

    def get_index_k_scale_continuous(
        self,
        layer_id: int,
        seq_len: int,
        page_indices: torch.Tensor,
    ):
        """Gather ``seq_len`` indexer key scales from ``page_indices`` via ``index_buf_accessor.GetS``."""
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        return index_buf_accessor.GetS.execute(
            self, buf, seq_len=seq_len, page_indices=page_indices
        )

    # TODO rename later (currently use diff name to avoid confusion)
    def set_index_k_and_scale_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
        index_k_scale: torch.Tensor,
    ) -> None:
        """Write indexer keys and their scales for token slots ``loc`` via ``index_buf_accessor.SetKAndS``."""
        buf = self.index_k_with_scale_buffer[layer_id - self.start_layer]
        index_buf_accessor.SetKAndS.execute(
            pool=self, buf=buf, loc=loc, index_k=index_k, index_k_scale=index_k_scale
        )

    def get_kv_size_bytes(self):
        """Return the MLA cache size plus the size of all indexer buffers, in bytes."""
        kv_size_bytes = super().get_kv_size_bytes()
        for index_k_cache in self.index_k_with_scale_buffer:
            kv_size_bytes += get_tensor_size_bytes(index_k_cache)
        return kv_size_bytes
class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    """Paged MLA pool for Ascend NPUs.

    Unlike ``MLATokenToKVPool``, the latent part (``k_buffer``) and the rope
    part (``v_buffer``) live in separate tensors shaped
    ``(layer_num, num_pages + 1, page_size, 1, width)``. An optional
    ``index_k_buffer`` with the same paging is allocated when
    ``index_head_dim`` is given.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        index_head_dim: Optional[int],
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Allocate the paged latent, rope and (optional) indexer buffers."""
        # Skip MLATokenToKVPool.__init__ (which allocates the combined
        # kv_buffer) and initialize the base KVCache directly.
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.index_head_dim = index_head_dim

        self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    1,
                    self.kv_lora_rank,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.v_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    1,
                    self.qk_rope_head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            if self.index_head_dim is not None:
                self.index_k_buffer = torch.zeros(
                    (
                        layer_num,
                        self.size // self.page_size + 1,
                        self.page_size,
                        1,
                        self.index_head_dim,
                    ),
                    dtype=self.store_dtype,
                    device=self.device,
                )

        self._finalize_allocation_log(size)

    def get_kv_size_bytes(self):
        """Return the total size in bytes of the k, v and (optional) indexer buffers."""
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        kv_size_bytes = 0
        for k_cache in self.k_buffer:
            kv_size_bytes += get_tensor_size_bytes(k_cache)
        for v_cache in self.v_buffer:
            kv_size_bytes += get_tensor_size_bytes(v_cache)
        if self.index_head_dim is not None:
            assert hasattr(self, "index_k_buffer")
            for index_k_cache in self.index_k_buffer:
                kv_size_bytes += get_tensor_size_bytes(index_k_cache)
        return kv_size_bytes

    def get_kv_buffer(self, layer_id: int):
        """Return the raw ``(k_buffer, v_buffer)`` of ``layer_id`` without a dtype view."""
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def get_key_buffer(self, layer_id: int):
        """Return the latent buffer of ``layer_id``, viewed as ``self.dtype``."""
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        """Return the rope buffer of ``layer_id``, viewed as ``self.dtype``."""
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    def get_index_k_buffer(self, layer_id: int):
        """Return the indexer key buffer of ``layer_id``, viewed as ``self.dtype``."""
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.index_k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.index_k_buffer[layer_id - self.start_layer]

    # for disagg
    def get_contiguous_buf_infos(self):
        """Describe the per-layer buffers for transfer.

        Returns:
            ``(data_ptrs, data_lens, item_lens)``: all k_buffer layers, then all
            v_buffer layers, then (if allocated) all index_k_buffer layers.
            ``item_lens`` is the number of bytes in one page of that buffer.
        """
        # This pool keeps separate k, v (and optionally indexer) buffers, so
        # each one is described per layer.
        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
        ]
        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i].nbytes for i in range(self.layer_num)
        ]
        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
        ]
        if self.index_head_dim is not None:
            kv_data_ptrs += [
                self.index_k_buffer[i].data_ptr() for i in range(self.layer_num)
            ]
            kv_data_lens += [
                self.index_k_buffer[i].nbytes for i in range(self.layer_num)
            ]
            kv_item_lens += [
                self.index_k_buffer[i][0].nbytes for i in range(self.layer_num)
            ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Scatter latent (``cache_k``) and rope (``cache_v``) rows into token slots ``loc``.

        If ``cache_v`` is None, ``cache_k`` is treated as the concatenated
        ``[latent | rope]`` row and split along the last dimension.
        """
        layer_id = layer.layer_id
        # Split before any dtype handling. The old code converted/viewed
        # ``cache_v`` first and crashed on ``None`` when a cast or view was needed.
        if cache_v is None:
            cache_k, cache_v = cache_k.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        torch_npu.npu_scatter_nd_update_(
            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
            loc.view(-1, 1),
            cache_k.view(-1, 1, self.kv_lora_rank),
        )
        torch_npu.npu_scatter_nd_update_(
            self.v_buffer[layer_id - self.start_layer].view(
                -1, 1, self.qk_rope_head_dim
            ),
            loc.view(-1, 1),
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )

    def set_index_k_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        index_k: torch.Tensor,
    ):
        """Scatter indexer keys into token slots ``loc`` of ``layer_id``."""
        if index_k.dtype != self.dtype:
            index_k = index_k.to(self.dtype)

        if self.store_dtype != self.dtype:
            index_k = index_k.view(self.store_dtype)

        torch_npu.npu_scatter_nd_update_(
            self.index_k_buffer[layer_id - self.start_layer].view(
                -1, 1, self.index_head_dim
            ),
            loc.view(-1, 1),
            index_k.view(-1, 1, self.index_head_dim),
        )
class DoubleSparseTokenToKVPool(KVCache):
    """KV pool for double-sparsity attention.

    Per layer it holds the key and value buffers plus a "label" buffer of
    ``heavy_channel_num`` channels per head.
    """

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        """Allocate zero-initialized key, value and label buffers for every layer."""
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        # All three buffers share the same token-slot count, so any ``loc``
        # valid for k/v is also valid for the label buffer. The label buffer
        # was previously sized ``size + 1``, which is too small when page_size > 1.
        num_slots = size + page_size

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size + page_size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros((num_slots, head_num, head_dim), dtype=dtype, device=device)
                for _ in range(layer_num)
            ]
            self.v_buffer = [
                torch.zeros((num_slots, head_num, head_dim), dtype=dtype, device=device)
                for _ in range(layer_num)
            ]

            # [size + page_size, head_num, heavy_channel_num] for each layer
            self.label_buffer = [
                torch.zeros(
                    (num_slots, head_num, heavy_channel_num), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

    def get_key_buffer(self, layer_id: int):
        """Return the key buffer of global layer ``layer_id``."""
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        """Return the value buffer of global layer ``layer_id``."""
        return self.v_buffer[layer_id - self.start_layer]

    def get_label_buffer(self, layer_id: int):
        """Return the label buffer of global layer ``layer_id``."""
        return self.label_buffer[layer_id - self.start_layer]

    def get_kv_buffer(self, layer_id: int):
        """Return ``(key_buffer, value_buffer)`` for global layer ``layer_id``."""
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        """Write keys, values and labels for token slots ``loc`` without dtype conversion."""
        # NOTE(Andy): ignore the dtype check
        layer_id = layer.layer_id
        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
        self.label_buffer[layer_id - self.start_layer][loc] = cache_label
@triton.jit
def copy_all_layer_kv_cache_tiled(
    data_ptrs,  # per-buffer base addresses, loaded as integers and cast to uint8 pointers
    strides,  # per-buffer distance in bytes between consecutive locations
    tgt_loc_ptr,  # destination location indices
    src_loc_ptr,  # source location indices
    num_locs,  # number of valid entries in the src/tgt index arrays
    num_locs_upper: tl.constexpr,  # compile-time bound >= num_locs; presumably a power of two, as tl.arange requires
    BYTES_PER_TILE: tl.constexpr,  # number of bytes of each location handled by one program
):
    """2D tiled kernel. Safe for in-place copy.

    For every buffer ``b`` and every ``i < num_locs``, copies the bytes of
    location ``src[i]`` onto location ``tgt[i]`` within that same buffer.
    Grid axis 0 selects the buffer (an entry of ``data_ptrs``/``strides``).
    Grid axis 1 selects a ``BYTES_PER_TILE``-wide byte tile of each location.
    Bytes at or beyond the buffer's stride are masked off.
    """
    bid = tl.program_id(0)
    tid = tl.program_id(1)

    stride = tl.load(strides + bid)
    base_ptr = tl.load(data_ptrs + bid)
    # Address the buffer byte-wise so ``stride`` and ``byte_off`` are in bytes.
    base_ptr = tl.cast(base_ptr, tl.pointer_type(tl.uint8))

    byte_off = tid * BYTES_PER_TILE + tl.arange(0, BYTES_PER_TILE)
    mask_byte = byte_off < stride
    # NOTE(review): the result of tl.multiple_of is not assigned, so this hint
    # may not reach byte_off. Confirm the intended alignment before changing
    # it to ``byte_off = tl.multiple_of(...)``.
    tl.multiple_of(byte_off, 16)

    loc_idx = tl.arange(0, num_locs_upper)
    mask_loc = loc_idx < num_locs

    src = tl.load(src_loc_ptr + loc_idx, mask=mask_loc, other=0)
    tgt = tl.load(tgt_loc_ptr + loc_idx, mask=mask_loc, other=0)

    # (num_locs_upper, BYTES_PER_TILE) grids of byte addresses.
    src_ptr = base_ptr + src[:, None] * stride + byte_off[None, :]
    tgt_ptr = base_ptr + tgt[:, None] * stride + byte_off[None, :]

    mask = mask_loc[:, None] & mask_byte[None, :]
    # The whole tile is loaded before anything is stored, so overlapping src/tgt
    # locations within this tile read the pre-copy values.
    vals = tl.load(src_ptr, mask=mask)
    tl.store(tgt_ptr, vals, mask=mask)