252 lines
7.7 KiB
Python
252 lines
7.7 KiB
Python
"""
|
|
Copyright 2023-2024 SGLang Team
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
you may not use this file except in compliance with the License.
|
|
You may obtain a copy of the License at
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
See the License for the specific language governing permissions and
|
|
limitations under the License.
|
|
"""
|
|
|
|
"""Memory pool."""
|
|
|
|
import logging
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union

import torch
|
|
|
|
# Module-level logger, following the standard getLogger(__name__) convention.
# Not used by any code visible in this chunk.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ReqToTokenPool:
    """A memory pool that maps a request to its token locations.

    Each request occupies one slot (one row of ``req_to_token``); row ``i``
    holds the KV-cache locations of the tokens of the request in slot ``i``.
    """

    def __init__(self, size: int, max_context_len: int, device: str):
        # Total number of request slots.
        self.size = size
        # Slot indices that are currently unassigned.
        self.free_slots = list(range(size))
        # req_to_token[slot, position] -> token location in the KV cache.
        self.req_to_token = torch.empty(
            (size, max_context_len), dtype=torch.int32, device=device
        )

    def alloc(self, need_size: int) -> Optional[List[int]]:
        """Allocate ``need_size`` request slots.

        Returns the allocated slot indices, or ``None`` if fewer than
        ``need_size`` slots are free (callers must check for ``None``).
        """
        if need_size > len(self.free_slots):
            return None

        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: Union[int, List[int]]):
        """Return one slot index, or a list of slot indices, to the pool."""
        if isinstance(free_index, int):
            self.free_slots.append(free_index)
        else:
            self.free_slots.extend(free_index)

    def clear(self):
        """Reset the pool so every slot is free again."""
        self.free_slots = list(range(self.size))
|
|
|
|
|
|
class BaseTokenToKVPool(ABC):
|
|
"""A memory pool that maps a token to its kv cache locations"""
|
|
|
|
def __init__(
|
|
self,
|
|
size: int,
|
|
dtype: torch.dtype,
|
|
):
|
|
self.size = size
|
|
self.dtype = dtype
|
|
if dtype == torch.float8_e5m2:
|
|
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
|
self.store_dtype = torch.uint8
|
|
else:
|
|
self.store_dtype = dtype
|
|
|
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
|
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
|
|
|
# Prefetch buffer
|
|
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
|
self.prefetch_chunk_size = 512
|
|
|
|
self.can_use_mem_size = self.size
|
|
self.clear()
|
|
|
|
def available_size(self):
|
|
return self.can_use_mem_size + len(self.prefetch_buffer)
|
|
|
|
def alloc(self, need_size: int):
|
|
buffer_len = len(self.prefetch_buffer)
|
|
if need_size <= buffer_len:
|
|
select_index = self.prefetch_buffer[:need_size]
|
|
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
|
return select_index
|
|
|
|
addition_size = need_size - buffer_len
|
|
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
|
select_index = (
|
|
torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
|
|
)
|
|
|
|
if select_index.shape[0] < addition_size:
|
|
return None
|
|
|
|
self.mem_state[select_index] = False
|
|
self.can_use_mem_size -= len(select_index)
|
|
|
|
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
|
|
ret_index = self.prefetch_buffer[:need_size]
|
|
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
|
|
|
return ret_index
|
|
|
|
def free(self, free_index: torch.Tensor):
|
|
self.mem_state[free_index] = True
|
|
self.can_use_mem_size += len(free_index)
|
|
|
|
def clear(self):
|
|
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
|
|
|
self.mem_state.fill_(True)
|
|
self.can_use_mem_size = self.size
|
|
|
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
|
self.mem_state[0] = False
|
|
|
|
@abstractmethod
|
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
raise NotImplementedError()
|
|
|
|
@abstractmethod
|
|
def set_kv_buffer(
|
|
self,
|
|
layer_id: int,
|
|
loc: torch.Tensor,
|
|
cache_k: torch.Tensor,
|
|
cache_v: torch.Tensor,
|
|
) -> None:
|
|
raise NotImplementedError()
|
|
|
|
|
|
class MHATokenToKVPool(BaseTokenToKVPool):
    """Token-to-KV pool for standard multi-head attention: one K buffer and
    one V buffer per layer, each of shape [size + 1, head_num, head_dim]."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        head_num: int,
        head_dim: int,
        layer_num: int,
    ):
        super().__init__(size, dtype)

        # The extra row (+1) matches the dummy slot reserved by the base pool.
        def _layer_buffer():
            return torch.empty(
                (size + 1, head_num, head_dim),
                dtype=self.store_dtype,
                device="cuda",
            )

        self.k_buffer = [_layer_buffer() for _ in range(layer_num)]
        self.v_buffer = [_layer_buffer() for _ in range(layer_num)]

    def get_key_buffer(self, layer_id: int):
        """Key cache of ``layer_id``, reinterpreted as the logical dtype."""
        buf = self.k_buffer[layer_id]
        return buf if self.store_dtype == self.dtype else buf.view(self.dtype)

    def get_value_buffer(self, layer_id: int):
        """Value cache of ``layer_id``, reinterpreted as the logical dtype."""
        buf = self.v_buffer[layer_id]
        return buf if self.store_dtype == self.dtype else buf.view(self.dtype)

    def get_kv_buffer(self, layer_id: int):
        """Return the (key, value) cache pair of ``layer_id``."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Write new K/V entries of ``layer_id`` at token locations ``loc``,
        converting to the logical dtype and storage dtype as needed."""
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if cache_v.dtype != self.dtype:
            cache_v = cache_v.to(self.dtype)
        if self.store_dtype != self.dtype:
            cache_k = cache_k.view(self.store_dtype)
            cache_v = cache_v.view(self.store_dtype)
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
|
|
|
|
|
|
class MLATokenToKVPool(BaseTokenToKVPool):
    """Token-to-KV pool for MLA-style attention: a single packed buffer per
    layer of shape [size + 1, 1, kv_lora_rank + qk_rope_head_dim]."""

    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        layer_num: int,
    ):
        super().__init__(size, dtype)

        self.kv_lora_rank = kv_lora_rank
        # Packed layout per token: [lora-compressed kv | rope dims];
        # the extra row (+1) matches the dummy slot reserved by the base pool.
        packed_shape = (size + 1, 1, kv_lora_rank + qk_rope_head_dim)
        self.kv_buffer = [
            torch.empty(packed_shape, dtype=self.store_dtype, device="cuda")
            for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        """Full packed buffer of ``layer_id``, as the logical dtype."""
        kv = self.kv_buffer[layer_id]
        return kv if self.store_dtype == self.dtype else kv.view(self.dtype)

    def get_value_buffer(self, layer_id: int):
        """Compressed-kv slice (first ``kv_lora_rank`` dims) of ``layer_id``."""
        compressed = self.kv_buffer[layer_id][..., : self.kv_lora_rank]
        if self.store_dtype == self.dtype:
            return compressed
        return compressed.view(self.dtype)

    def get_kv_buffer(self, layer_id: int):
        """Return the (packed, compressed) buffer pair of ``layer_id``."""
        return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id)

    def set_kv_buffer(
        self,
        layer_id: int,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
    ):
        """Write packed entries of ``layer_id`` at token locations ``loc``.

        ``cache_v`` is deliberately ignored here — NOTE(review): presumably
        the full packed row (compressed kv + rope) arrives in ``cache_k``;
        confirm against callers.
        """
        if cache_k.dtype != self.dtype:
            cache_k = cache_k.to(self.dtype)
        if self.store_dtype == self.dtype:
            self.kv_buffer[layer_id][loc] = cache_k
        else:
            self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)