"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
|
|
|
"""
|
|
Memory pool.
|
|
|
|
SGLang has two levels of memory pool.
|
|
ReqToTokenPool maps a request to its token locations.
|
|
TokenToKVPoolAllocator manages the indices to kv cache data.
|
|
KVCache actually holds the physical kv cache.
|
|
"""

import abc
import logging
from contextlib import nullcontext
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import triton
import triton.language as tl

from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.managers.cache_controller import LayerDoneCounter

logger = logging.getLogger(__name__)

GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if _is_npu:
    import torch_npu


def get_tensor_size_bytes(t: torch.Tensor):
    return np.prod(t.shape) * t.dtype.itemsize


class ReqToTokenPool:
    """A memory pool that maps a request to its token locations."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
    ):
        memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )

        self.size = size
        self.max_context_len = max_context_len
        self.device = device
        with memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.req_to_token = torch.zeros(
                (size, max_context_len), dtype=torch.int32, device=device
            )

        self.free_slots = list(range(size))

    def write(self, indices, values):
        self.req_to_token[indices] = values

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        self.free_slots = list(range(self.size))
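
    # Illustrative usage of the free-list allocator above (a sketch; real
    # callers live elsewhere in SGLang):
    #   pool = ReqToTokenPool(size=4, max_context_len=8, device="cpu",
    #                         enable_memory_saver=False)
    #   slots = pool.alloc(2)  # -> [0, 1], taken from the head of the free list
    #   pool.free(slots)       # slots are appended back for reuse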


class MambaPool:
    def __init__(
        self,
        size: int,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        num_mamba_layers: int,
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        device: str,
        speculative_num_draft_tokens: Optional[int] = None,
    ):
        conv_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + conv_state_shape,
            dtype=conv_dtype,
            device=device,
        )
        temporal_state = torch.zeros(
            size=(num_mamba_layers, size + 1) + temporal_state_shape,
            dtype=ssm_dtype,
            device=device,
        )
        if speculative_num_draft_tokens is not None:
            # Cache intermediate SSM states per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, HV, K, V]
            intermediate_ssm_state_cache = torch.zeros(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    temporal_state_shape[0],
                    temporal_state_shape[1],
                    temporal_state_shape[2],
                ),
                dtype=ssm_dtype,
                device="cuda",
            )
            # Cache intermediate conv windows (last K-1 inputs) per draft token during target verify
            # Shape: [num_layers, size + 1, speculative_num_draft_tokens, dim, K-1]
            intermediate_conv_window_cache = torch.zeros(
                size=(
                    num_mamba_layers,
                    size + 1,
                    speculative_num_draft_tokens,
                    conv_state_shape[0],
                    conv_state_shape[1],
                ),
                dtype=conv_dtype,
                device="cuda",
            )
            self.mamba_cache = (
                conv_state,
                temporal_state,
                intermediate_ssm_state_cache,
                intermediate_conv_window_cache,
            )
            logger.info(
                f"Mamba Cache is allocated. "
                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB, "
                f"intermediate_ssm_state_cache size: {get_tensor_size_bytes(intermediate_ssm_state_cache) / GB:.2f}GB, "
                f"intermediate_conv_window_cache size: {get_tensor_size_bytes(intermediate_conv_window_cache) / GB:.2f}GB"
            )
        else:
            self.mamba_cache = (conv_state, temporal_state)
            logger.info(
                f"Mamba Cache is allocated. "
                f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
                f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB"
            )
        self.size = size
        self.free_slots = list(range(size))
        self.mem_usage = self.get_mamba_size() / GB

    def get_mamba_params_all_layers(self):
        return [self.mamba_cache[i] for i in range(len(self.mamba_cache))]

    def get_mamba_params(self, layer_id: int):
        return [self.mamba_cache[i][layer_id] for i in range(len(self.mamba_cache))]

    def get_mamba_size(self):
        return sum(get_tensor_size_bytes(t) for t in self.mamba_cache)

    def available_size(self):
        return len(self.free_slots)

    def alloc(self, need_size: int) -> Optional[List[int]]:
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]

        return select_index

    def free(self, free_index: Union[int, List[int]]):
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)
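        # Zero out the conv and temporal states of freed slots so that a
        # reused slot starts from a clean state.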
        self.mamba_cache[0][:, free_index] = self.mamba_cache[1][:, free_index] = 0

    def clear(self):
        self.free_slots = list(range(self.size))


class HybridReqToTokenPool(ReqToTokenPool):
    """A memory pool that maps a request to its token locations and, for
    hybrid models, to its per-request Mamba state slot."""

    def __init__(
        self,
        size: int,
        max_context_len: int,
        device: str,
        enable_memory_saver: bool,
        conv_dtype: torch.dtype,
        ssm_dtype: torch.dtype,
        mamba_layers: List[int],
        conv_state_shape: Tuple[int, int],
        temporal_state_shape: Tuple[int, int],
        speculative_num_draft_tokens: int,
    ):
        super().__init__(
            size=size,
            max_context_len=max_context_len,
            device=device,
            enable_memory_saver=enable_memory_saver,
        )

        self.mamba_pool = MambaPool(
            size,
            conv_dtype,
            ssm_dtype,
            len(mamba_layers),
            conv_state_shape,
            temporal_state_shape,
            device,
            speculative_num_draft_tokens,
        )
        self.mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layers)}

        self.device = device
        self.req_index_to_mamba_index_mapping: torch.Tensor = torch.zeros(
            size, dtype=torch.int32, device=self.device
        )

        self.rid_to_mamba_index_mapping: Dict[str, int] = {}
        self.mamba_index_to_rid_mapping: Dict[int, str] = {}

    # For a chunked prefill request, we do not need to allocate a new mamba
    # cache slot; the slot allocated for an earlier chunk is reused instead.
    def alloc(
        self, need_size: int, reqs: Optional[List["Req"]] = None
    ) -> Optional[List[int]]:
        select_index = super().alloc(need_size)
        if select_index is None:
            return None

        mamba_index = []
        for req in reqs:
            rid = req.rid
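            # Reuse the mamba slot from an earlier chunk of the same request
            # (chunked prefill); otherwise allocate a fresh slot.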
            if rid in self.rid_to_mamba_index_mapping:
                mid = self.rid_to_mamba_index_mapping[rid]
            elif (mid := self.mamba_pool.alloc(1)) is not None:
                mid = mid[0]
                self.rid_to_mamba_index_mapping[rid] = mid
                self.mamba_index_to_rid_mapping[mid] = rid
            mamba_index.append(mid)
        assert len(select_index) == len(
            mamba_index
        ), "Not enough space for mamba cache; try increasing --max-mamba-cache-size."
        self.req_index_to_mamba_index_mapping[select_index] = torch.tensor(
            mamba_index, dtype=torch.int32, device=self.device
        )
        return select_index

    def get_mamba_indices(self, req_indices: torch.Tensor) -> torch.Tensor:
        return self.req_index_to_mamba_index_mapping[req_indices]

    def get_mamba_params(self, layer_id: int):
        assert layer_id in self.mamba_map
        return self.mamba_pool.get_mamba_params(self.mamba_map[layer_id])

    def get_mamba_params_all_layers(self):
        return self.mamba_pool.get_mamba_params_all_layers()

    # For chunked prefill, we cannot free the mamba cache yet; it is still
    # needed by the request's future chunks.
    def free(self, free_index: Union[int, List[int]], free_mamba_cache: bool = True):
        super().free(free_index)
        if free_mamba_cache:
            mamba_index = self.req_index_to_mamba_index_mapping[free_index]
            mamba_index_list = mamba_index.tolist()
            if isinstance(mamba_index_list, int):
                mamba_index_list = [mamba_index_list]
            self.mamba_pool.free(mamba_index_list)
            for mid in mamba_index_list:
                rid = self.mamba_index_to_rid_mapping[mid]
                self.mamba_index_to_rid_mapping.pop(mid)
                self.rid_to_mamba_index_mapping.pop(rid)

    def clear(self):
        super().clear()
        self.mamba_pool.clear()


class KVCache(abc.ABC):
    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        self.size = size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.layer_num = layer_num
        self.start_layer = start_layer or 0
        self.end_layer = end_layer or layer_num - 1
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        self.mem_usage = 0

        # used for chunked cpu-offloading
        self.cpu_offloading_chunk_size = 8192

        # default state for optional layer-wise transfer control
        self.layer_transfer_counter = None
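        # When a counter is registered (register_layer_transfer_counter), the
        # get_*_buffer accessors block until that layer's KV data has finished
        # loading.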

    def _finalize_allocation_log(self, num_tokens: int):
        """Common logging and mem_usage computation for KV cache allocation.
        Supports both tuple (K, V) size returns and single KV size returns.
        """
        kv_size_bytes = self.get_kv_size_bytes()
        if isinstance(kv_size_bytes, tuple):
            k_size, v_size = kv_size_bytes
            k_size_GB = k_size / GB
            v_size_GB = v_size / GB
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, K size: {k_size_GB:.2f} GB, V size: {v_size_GB:.2f} GB"
            )
            self.mem_usage = k_size_GB + v_size_GB
        else:
            kv_size_GB = kv_size_bytes / GB
            logger.info(
                f"KV Cache is allocated. #tokens: {num_tokens}, KV size: {kv_size_GB:.2f} GB"
            )
            self.mem_usage = kv_size_GB

    @abc.abstractmethod
    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
        raise NotImplementedError()

    @abc.abstractmethod
    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError()

    @abc.abstractmethod
    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ) -> None:
        raise NotImplementedError()

    def register_layer_transfer_counter(self, layer_transfer_counter: LayerDoneCounter):
        self.layer_transfer_counter = layer_transfer_counter

    def get_cpu_copy(self, indices):
        raise NotImplementedError()

    def load_cpu_copy(self, kv_cache_cpu, indices):
        raise NotImplementedError()


class MHATokenToKVPool(KVCache):

    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )
        self.head_num = head_num
        self.head_dim = head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            # TODO(shangming): abstract custom allocator class for more backends
            from mooncake.allocator import NVLinkAllocator

            allocator = NVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        self._create_buffers()

        self.device_module = torch.get_device_module(self.device)
        self.alt_stream = self.device_module.Stream() if _is_cuda else None
        self._finalize_allocation_log(size)

    def _create_buffers(self):
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.enable_custom_mem_pool
                else nullcontext()
            ):
                # [size, head_num, head_dim] for each layer
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.k_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]
                self.v_buffer = [
                    torch.zeros(
                        (self.size + self.page_size, self.head_num, self.head_dim),
                        dtype=self.store_dtype,
                        device=self.device,
                    )
                    for _ in range(self.layer_num)
                ]
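
        # Per-layer base pointers and per-token strides for the K and V
        # buffers; consumed by the copy_all_layer_kv_cache Triton kernel in
        # move_kv_cache().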
        self.k_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.k_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.v_data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.v_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self.data_ptrs = torch.cat([self.k_data_ptrs, self.v_data_ptrs], dim=0)
        self.data_strides = torch.tensor(
            [
                np.prod(x.shape[1:]) * x.dtype.itemsize
                for x in self.k_buffer + self.v_buffer
            ],
            device=self.device,
        )

    def _clear_buffers(self):
        del self.k_buffer
        del self.v_buffer

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        k_size_bytes = 0
        for k_cache in self.k_buffer:
            k_size_bytes += get_tensor_size_bytes(k_cache)
        v_size_bytes = 0
        for v_cache in self.v_buffer:
            v_size_bytes += get_tensor_size_bytes(v_cache)
        return k_size_bytes, v_size_bytes

    # for disagg
    def get_contiguous_buf_infos(self):
        # layer_num x [seq_len, head_num, head_dim]
        # layer_num x [page_num, page_size, head_num, head_dim]
        kv_data_ptrs = [
            self._get_key_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_data_lens = [
            self._get_key_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_item_lens = [
            self._get_key_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self._get_value_buffer(i)[0].nbytes * self.page_size
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                k_cpu = self.k_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                v_cpu = self.v_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                kv_cache_cpu[-1].append([k_cpu, v_cpu])
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        torch.cuda.synchronize()
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                k_cpu, v_cpu = (
                    kv_cache_cpu[layer_id][i // chunk_size][0],
                    kv_cache_cpu[layer_id][i // chunk_size][1],
                )
                assert k_cpu.shape[0] == v_cpu.shape[0] == len(chunk_indices)
                k_chunk = k_cpu.to(self.k_buffer[0].device, non_blocking=True)
                v_chunk = v_cpu.to(self.v_buffer[0].device, non_blocking=True)
                self.k_buffer[layer_id][chunk_indices] = k_chunk
                self.v_buffer[layer_id][chunk_indices] = v_chunk
        torch.cuda.synchronize()

    def _get_key_buffer(self, layer_id: int):
        # internal accessor; skips layer-transfer synchronization
        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_key_buffer(self, layer_id: int):
        # NOTE: get_key_buffer is hooked with synchronization for layer-wise
        # KV cache loading. It is supposed to be used only by the attention
        # backend, not for informational purposes. The same applies to
        # get_value_buffer and get_kv_buffer.
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self._get_key_buffer(layer_id)

    def _get_value_buffer(self, layer_id: int):
        # internal accessor; skips layer-transfer synchronization
        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return self._get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

        if layer_id_override is not None:
            layer_id = layer_id_override
        else:
            layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        if get_is_capture_mode() and self.alt_stream is not None:
            # Overlap the copy of K and V cache for small batch size
            current_stream = self.device_module.current_stream()
            self.alt_stream.wait_stream(current_stream)
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            with self.device_module.stream(self.alt_stream):
                self.v_buffer[layer_id - self.start_layer][loc] = cache_v
            current_stream.wait_stream(self.alt_stream)
        else:
            self.k_buffer[layer_id - self.start_layer][loc] = cache_k
            self.v_buffer[layer_id - self.start_layer][loc] = cache_v

    def move_kv_cache(self, tgt_loc: torch.Tensor, src_loc: torch.Tensor):
        copy_all_layer_kv_cache[(len(self.data_ptrs),)](
            self.data_ptrs,
            self.data_strides,
            tgt_loc,
            src_loc,
            len(tgt_loc),
            next_power_of_2(len(tgt_loc)),
        )


class HybridLinearKVPool(KVCache):
    """KV cache with separate pools for full and linear attention layers."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        page_size: int,
        head_num: int,
        head_dim: int,
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        device: str,
    ):
        self.size = size
        self.dtype = dtype
        self.device = device
        self.full_layer_nums = len(full_attention_layer_ids)
        self.page_size = page_size
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose
        if _is_npu:
            TokenToKVPoolClass = AscendTokenToKVPool
        else:
            TokenToKVPoolClass = MHATokenToKVPool
        self.full_kv_pool = TokenToKVPoolClass(
            size=size,
            page_size=self.page_size,
            dtype=dtype,
            head_num=head_num,
            head_dim=head_dim,
            layer_num=self.full_layer_nums,
            device=device,
            enable_memory_saver=False,
        )
        self.full_attention_layer_id_mapping = {
            id: i for i, id in enumerate(full_attention_layer_ids)
        }
        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        return self.full_kv_pool.get_kv_size_bytes()

    def get_contiguous_buf_infos(self):
        return self.full_kv_pool.get_contiguous_buf_infos()

    def _transfer_full_attention_id(self, layer_id: int):
        if layer_id not in self.full_attention_layer_id_mapping:
            raise ValueError(
                f"{layer_id=} not in full attention layers: {self.full_attention_layer_id_mapping.keys()}"
            )
        return self.full_attention_layer_id_mapping[layer_id]

    def get_key_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_key_buffer(layer_id)

    def get_value_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_value_buffer(layer_id)

    def get_kv_buffer(self, layer_id: int):
        layer_id = self._transfer_full_attention_id(layer_id)
        return self.full_kv_pool.get_kv_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = self._transfer_full_attention_id(layer.layer_id)
        self.full_kv_pool.set_kv_buffer(
            None,
            loc,
            cache_k,
            cache_v,
            k_scale,
            v_scale,
            layer_id_override=layer_id,
        )

    def get_v_head_dim(self):
        return self.full_kv_pool.get_value_buffer(0).shape[-1]


class SWAKVPool(KVCache):
    """KV cache with separate pools for full and SWA attention layers."""

    def __init__(
        self,
        size: int,
        size_swa: int,
        swa_attention_layer_ids: List[int],
        full_attention_layer_ids: List[int],
        enable_kvcache_transpose: bool,
        token_to_kv_pool_class: KVCache = MHATokenToKVPool,
        **kwargs,
    ):
        self.size = size
        self.size_swa = size_swa
        self.swa_layer_nums = len(swa_attention_layer_ids)
        self.full_layer_nums = len(full_attention_layer_ids)
        kwargs["page_size"] = 1
        kwargs["enable_memory_saver"] = False
        # TODO MHATransposedTokenToKVPool if enable_kvcache_transpose is True
        assert not enable_kvcache_transpose

        self.swa_kv_pool = token_to_kv_pool_class(
            size=size_swa,
            layer_num=self.swa_layer_nums,
            **kwargs,
        )
        self.full_kv_pool = token_to_kv_pool_class(
            size=size,
            layer_num=self.full_layer_nums,
            **kwargs,
        )
        self.layers_mapping: Dict[int, Tuple[int, bool]] = {}
        for full_attn_layer_id, global_layer_id in enumerate(full_attention_layer_ids):
            self.layers_mapping[global_layer_id] = (full_attn_layer_id, False)
        for swa_layer_id, global_layer_id in enumerate(swa_attention_layer_ids):
            self.layers_mapping[global_layer_id] = (swa_layer_id, True)
        self.full_to_swa_index_mapping: Optional[torch.Tensor] = None

        k_size, v_size = self.get_kv_size_bytes()
        self.mem_usage = (k_size + v_size) / GB

    def get_kv_size_bytes(self):
        k_size, v_size = self.full_kv_pool.get_kv_size_bytes()
        k_size_swa, v_size_swa = self.swa_kv_pool.get_kv_size_bytes()
        return k_size + k_size_swa, v_size + v_size_swa

    def get_contiguous_buf_infos(self):
        full_kv_data_ptrs, full_kv_data_lens, full_kv_item_lens = (
            self.full_kv_pool.get_contiguous_buf_infos()
        )
        swa_kv_data_ptrs, swa_kv_data_lens, swa_kv_item_lens = (
            self.swa_kv_pool.get_contiguous_buf_infos()
        )

        kv_data_ptrs = full_kv_data_ptrs + swa_kv_data_ptrs
        kv_data_lens = full_kv_data_lens + swa_kv_data_lens
        kv_item_lens = full_kv_item_lens + swa_kv_item_lens

        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def get_key_buffer(self, layer_id: int):
        layer_id_pool, is_swa = self.layers_mapping[layer_id]
        if is_swa:
            return self.swa_kv_pool.get_key_buffer(layer_id_pool)
        else:
            return self.full_kv_pool.get_key_buffer(layer_id_pool)

    def get_value_buffer(self, layer_id: int):
        layer_id_pool, is_swa = self.layers_mapping[layer_id]
        if is_swa:
            return self.swa_kv_pool.get_value_buffer(layer_id_pool)
        else:
            return self.full_kv_pool.get_value_buffer(layer_id_pool)

    def get_kv_buffer(self, layer_id: int):
        layer_id_pool, is_swa = self.layers_mapping[layer_id]
        if is_swa:
            return self.swa_kv_pool.get_kv_buffer(layer_id_pool)
        else:
            return self.full_kv_pool.get_kv_buffer(layer_id_pool)

    def translate_loc_from_full_to_swa(self, kv_indices: torch.Tensor):
        assert self.full_to_swa_index_mapping is not None
        return self.full_to_swa_index_mapping[kv_indices].to(torch.int32)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: float = 1.0,
        v_scale: float = 1.0,
    ):
        layer_id = layer.layer_id
        layer_id_pool, is_swa = self.layers_mapping[layer_id]
        if is_swa:
            if self.full_to_swa_index_mapping is not None:
                loc = self.translate_loc_from_full_to_swa(loc)
            self.swa_kv_pool.set_kv_buffer(
                None,
                loc,
                cache_k,
                cache_v,
                k_scale,
                v_scale,
                layer_id_override=layer_id_pool,
            )
        else:
            self.full_kv_pool.set_kv_buffer(
                None,
                loc,
                cache_k,
                cache_v,
                k_scale,
                v_scale,
                layer_id_override=layer_id_pool,
            )


class AscendTokenToKVPool(MHATokenToKVPool):

    def _create_buffers(self):
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size, head_num, head_dim] for each layer
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            # Contiguous memory improves the efficiency of Ascend's transmission backend,
            # while other backends remain unchanged.
            self.kv_buffer = torch.zeros(
                (
                    2,
                    self.layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    self.head_num,
                    self.head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.k_buffer = self.kv_buffer[0]
            self.v_buffer = self.kv_buffer[1]

    # for disagg
    def get_contiguous_buf_infos(self):
        # layer_num x [seq_len, head_num, head_dim]
        # layer_num x [page_num, page_size, head_num, head_dim]
        kv_data_ptrs = [
            self.get_key_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i).data_ptr()
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_data_lens = [
            self.get_key_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i).nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        kv_item_lens = [
            self.get_key_buffer(i)[0].nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ] + [
            self.get_value_buffer(i)[0].nbytes
            for i in range(self.start_layer, self.start_layer + self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        k_scale: Optional[float] = None,
        v_scale: Optional[float] = None,
        layer_id_override: Optional[int] = None,
    ):
        if layer_id_override is not None:
            layer_id = layer_id_override
        else:
            layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            if k_scale is not None:
                cache_k.div_(k_scale)
            if v_scale is not None:
                cache_v.div_(v_scale)
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        torch_npu._npu_reshape_and_cache(
            key=cache_k,
            value=cache_v,
            key_cache=self.k_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            value_cache=self.v_buffer[layer_id].view(
                -1, self.page_size, self.head_num, self.head_dim
            ),
            slot_indices=loc,
        )


@triton.jit
def set_mla_kv_buffer_kernel(
    kv_buffer_ptr,
    cache_k_nope_ptr,
    cache_k_rope_ptr,
    loc_ptr,
    buffer_stride: tl.constexpr,
    nope_stride: tl.constexpr,
    rope_stride: tl.constexpr,
    nope_dim: tl.constexpr,
    rope_dim: tl.constexpr,
    BLOCK: tl.constexpr,
):
    pid_loc = tl.program_id(0)
    pid_blk = tl.program_id(1)

    base = pid_blk * BLOCK
    offs = base + tl.arange(0, BLOCK)
    total_dim = nope_dim + rope_dim
    mask = offs < total_dim

    loc = tl.load(loc_ptr + pid_loc)
    dst_ptr = kv_buffer_ptr + loc * buffer_stride + offs
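
    # NOTE: this two-way branch assumes nope_dim is a multiple of BLOCK, so a
    # block never straddles the nope/rope boundary (true for typical MLA
    # configs, e.g. kv_lora_rank 512 with BLOCK 128).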
    if base + BLOCK <= nope_dim:
        src = tl.load(
            cache_k_nope_ptr + pid_loc * nope_stride + offs,
            mask=mask,
        )
    else:
        offs_rope = offs - nope_dim
        src = tl.load(
            cache_k_rope_ptr + pid_loc * rope_stride + offs_rope,
            mask=mask,
        )

    tl.store(dst_ptr, src, mask=mask)


def set_mla_kv_buffer_triton(
    kv_buffer: torch.Tensor,
    loc: torch.Tensor,
    cache_k_nope: torch.Tensor,
    cache_k_rope: torch.Tensor,
):
    nope_dim = cache_k_nope.shape[-1]
    rope_dim = cache_k_rope.shape[-1]
    total_dim = nope_dim + rope_dim
    BLOCK = 128
    n_loc = loc.numel()
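    # Grid: one program per (token, BLOCK-wide slice of the combined
    # nope+rope dimension).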
    grid = (n_loc, triton.cdiv(total_dim, BLOCK))

    set_mla_kv_buffer_kernel[grid](
        kv_buffer,
        cache_k_nope,
        cache_k_rope,
        loc,
        kv_buffer.stride(0),
        cache_k_nope.stride(0),
        cache_k_rope.stride(0),
        nope_dim,
        rope_dim,
        BLOCK=BLOCK,
    )


class MLATokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        # for disagg with nvlink
        self.enable_custom_mem_pool = get_bool_env_var(
            "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
        )
        if self.enable_custom_mem_pool:
            # TODO(shangming): abstract custom allocator class for more backends
            from mooncake.allocator import NVLinkAllocator

            allocator = NVLinkAllocator.get_allocator(self.device)
            self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
        else:
            self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            with (
                torch.cuda.use_mem_pool(self.custom_mem_pool)
                if self.custom_mem_pool
                else nullcontext()
            ):
                # The padded slot 0 is used for writing dummy outputs from padded tokens.
                self.kv_buffer = [
                    torch.zeros(
                        (size + page_size, 1, kv_lora_rank + qk_rope_head_dim),
                        dtype=self.store_dtype,
                        device=device,
                    )
                    for _ in range(layer_num)
                ]

        self.data_ptrs = torch.tensor(
            [x.data_ptr() for x in self.kv_buffer],
            dtype=torch.uint64,
            device=self.device,
        )
        self._finalize_allocation_log(size)

    def get_kv_size_bytes(self):
        assert hasattr(self, "kv_buffer")
        kv_size_bytes = 0
        for kv_cache in self.kv_buffer:
            kv_size_bytes += get_tensor_size_bytes(kv_cache)
        return kv_size_bytes

    # for disagg
    def get_contiguous_buf_infos(self):
        # MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
        kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
        kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
        kv_item_lens = [
            self.kv_buffer[i][0].nbytes * self.page_size for i in range(self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def maybe_get_custom_mem_pool(self):
        return self.custom_mem_pool

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.kv_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.kv_buffer[layer_id - self.start_layer][
                ..., : self.kv_lora_rank
            ].view(self.dtype)
        return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank]

    def get_kv_buffer(self, layer_id: int):
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype != self.dtype:
            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view(
                self.store_dtype
            )
        else:
            self.kv_buffer[layer_id - self.start_layer][loc] = cache_k

    def set_mla_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k_nope: torch.Tensor,
        cache_k_rope: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_k_nope.dtype != self.dtype:
            cache_k_nope = cache_k_nope.to(self.dtype)
            cache_k_rope = cache_k_rope.to(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k_nope = cache_k_nope.view(self.store_dtype)
            cache_k_rope = cache_k_rope.view(self.store_dtype)

        set_mla_kv_buffer_triton(
            self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
        )

    def get_cpu_copy(self, indices):
        torch.cuda.synchronize()
        kv_cache_cpu = []
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            kv_cache_cpu.append([])
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                kv_cpu = self.kv_buffer[layer_id][chunk_indices].to(
                    "cpu", non_blocking=True
                )
                kv_cache_cpu[-1].append(kv_cpu)
        torch.cuda.synchronize()
        return kv_cache_cpu

    def load_cpu_copy(self, kv_cache_cpu, indices):
        torch.cuda.synchronize()
        chunk_size = self.cpu_offloading_chunk_size
        for layer_id in range(self.layer_num):
            for i in range(0, len(indices), chunk_size):
                chunk_indices = indices[i : i + chunk_size]
                kv_cpu = kv_cache_cpu[layer_id][i // chunk_size]
                assert kv_cpu.shape[0] == len(chunk_indices)
                kv_chunk = kv_cpu.to(self.kv_buffer[0].device, non_blocking=True)
                self.kv_buffer[layer_id][chunk_indices] = kv_chunk
        torch.cuda.synchronize()


class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
        device: str,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        # Skip MLATokenToKVPool.__init__ and call KVCache.__init__ directly;
        # this class allocates its own paged buffers below.
        super(MLATokenToKVPool, self).__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim

        self.custom_mem_pool = None

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # The padded slot 0 is used for writing dummy outputs from padded tokens.
            self.k_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    1,
                    self.kv_lora_rank,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )
            self.v_buffer = torch.zeros(
                (
                    layer_num,
                    self.size // self.page_size + 1,
                    self.page_size,
                    1,
                    self.qk_rope_head_dim,
                ),
                dtype=self.store_dtype,
                device=self.device,
            )

        self._finalize_allocation_log(size)

    def get_kv_size_bytes(self):
        assert hasattr(self, "k_buffer")
        assert hasattr(self, "v_buffer")
        kv_size_bytes = 0
        for k_cache in self.k_buffer:
            kv_size_bytes += get_tensor_size_bytes(k_cache)
        for v_cache in self.v_buffer:
            kv_size_bytes += get_tensor_size_bytes(v_cache)
        return kv_size_bytes

    def get_kv_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def get_key_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.k_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        if self.layer_transfer_counter is not None:
            self.layer_transfer_counter.wait_until(layer_id - self.start_layer)

        if self.store_dtype != self.dtype:
            return self.v_buffer[layer_id - self.start_layer].view(self.dtype)
        return self.v_buffer[layer_id - self.start_layer]

    # for disagg
    def get_contiguous_buf_infos(self):
        # Unlike the base MLA pool, the lora (k_buffer) and rope (v_buffer)
        # parts are stored separately here, so both are returned.
        kv_data_ptrs = [self.k_buffer[i].data_ptr() for i in range(self.layer_num)] + [
            self.v_buffer[i].data_ptr() for i in range(self.layer_num)
        ]
        kv_data_lens = [self.k_buffer[i].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i].nbytes for i in range(self.layer_num)
        ]
        kv_item_lens = [self.k_buffer[i][0].nbytes for i in range(self.layer_num)] + [
            self.v_buffer[i][0].nbytes for i in range(self.layer_num)
        ]
        return kv_data_ptrs, kv_data_lens, kv_item_lens

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        layer_id = layer.layer_id
        if cache_v is None:
            # The combined (lora + rope) tensor may be passed in as cache_k;
            # split it before any dtype handling.
            cache_k, cache_v = cache_k.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
            )

        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
            cache_v = cache_v.to(self.dtype)

        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)

        torch_npu.npu_scatter_nd_update_(
            self.k_buffer[layer_id - self.start_layer].view(-1, 1, self.kv_lora_rank),
            loc.view(-1, 1),
            cache_k.view(-1, 1, self.kv_lora_rank),
        )
        torch_npu.npu_scatter_nd_update_(
            self.v_buffer[layer_id - self.start_layer].view(
                -1, 1, self.qk_rope_head_dim
            ),
            loc.view(-1, 1),
            cache_v.view(-1, 1, self.qk_rope_head_dim),
        )


class DoubleSparseTokenToKVPool(KVCache):
    def __init__(
        self,
        size: int,
        page_size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
        device: str,
        heavy_channel_num: int,
        enable_memory_saver: bool,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        super().__init__(
            size,
            page_size,
            dtype,
            layer_num,
            device,
            enable_memory_saver,
            start_layer,
            end_layer,
        )

        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            # [size, head_num, head_dim] for each layer
            self.k_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]
            self.v_buffer = [
                torch.zeros(
                    (size + page_size, head_num, head_dim), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

            # [size, head_num, heavy_channel_num] for each layer
            self.label_buffer = [
                torch.zeros(
                    (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
                )
                for _ in range(layer_num)
            ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id - self.start_layer]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id - self.start_layer]

    def get_label_buffer(self, layer_id: int):
        return self.label_buffer[layer_id - self.start_layer]

    def get_kv_buffer(self, layer_id: int):
        return (
            self.k_buffer[layer_id - self.start_layer],
            self.v_buffer[layer_id - self.start_layer],
        )

    def set_kv_buffer(
        self,
        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
    ):
        # NOTE(Andy): ignore the dtype check
        layer_id = layer.layer_id
        self.k_buffer[layer_id - self.start_layer][loc] = cache_k
        self.v_buffer[layer_id - self.start_layer][loc] = cache_v
        self.label_buffer[layer_id - self.start_layer][loc] = cache_label


@triton.jit
def copy_all_layer_kv_cache(
    data_ptrs,
    strides,
    tgt_loc_ptr,
    src_loc_ptr,
    num_locs,
    num_locs_upper: tl.constexpr,
):
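    # Launched with grid = (2 * layer_num,): one program per flattened K/V
    # layer buffer (see MHATokenToKVPool.move_kv_cache). Each program copies
    # num_locs rows of `stride` bytes from the src locations to the tgt
    # locations within that buffer.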
    BLOCK_SIZE: tl.constexpr = 128

    bid = tl.program_id(0)
    stride = tl.load(strides + bid)

    data_ptr = tl.load(data_ptrs + bid)
    data_ptr = tl.cast(data_ptr, tl.pointer_type(tl.uint8))

    num_locs_offset = tl.arange(0, num_locs_upper)
    tgt_locs = tl.load(tgt_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)
    src_locs = tl.load(src_loc_ptr + num_locs_offset, mask=num_locs_offset < num_locs)

    # NOTE: we cannot parallelize over the tgt_loc_ptr dim with cuda blocks
    # because this copy is an inplace operation.

    num_loop = tl.cdiv(stride, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = (num_locs_offset < num_locs)[:, None] & (copy_offset < stride)[None, :]
        value = tl.load(
            data_ptr + src_locs[:, None] * stride + copy_offset[None, :], mask=mask
        )
        tl.store(
            data_ptr + tgt_locs[:, None] * stride + copy_offset[None, :],
            value,
            mask=mask,
        )