Move status check in the memory pool to CPU (#1557)

This commit is contained in:
Lianmin Zheng
2024-10-02 18:23:35 -07:00
committed by GitHub
parent 317631cada
commit 4ae0969c0a
2 changed files with 27 additions and 42 deletions

View File

@@ -19,6 +19,7 @@ import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Union from typing import List, Tuple, Union
import numpy as np
import torch import torch
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
else: else:
self.store_dtype = dtype self.store_dtype = dtype
# We also add one slot. This slot is used for writing dummy output from padded tokens. self.free_slots = None
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
# Prefetch buffer
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
self.prefetch_chunk_size = 512
self.can_use_mem_size = self.size
self.clear() self.clear()
def available_size(self): def available_size(self):
return self.can_use_mem_size + len(self.prefetch_buffer) return len(self.free_slots)
def alloc(self, need_size: int): def alloc(self, need_size: int):
buffer_len = len(self.prefetch_buffer) if need_size > len(self.free_slots):
if need_size <= buffer_len:
select_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]
return select_index
addition_size = need_size - buffer_len
alloc_size = max(addition_size, self.prefetch_chunk_size)
select_index = (
torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
)
if select_index.shape[0] < addition_size:
return None return None
self.mem_state[select_index] = False select_index = self.free_slots[:need_size]
self.can_use_mem_size -= len(select_index) self.free_slots = self.free_slots[need_size:]
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index)) return torch.tensor(select_index, dtype=torch.int32, device="cuda")
ret_index = self.prefetch_buffer[:need_size]
self.prefetch_buffer = self.prefetch_buffer[need_size:]
return ret_index
def free(self, free_index: torch.Tensor): def free(self, free_index: torch.Tensor):
self.mem_state[free_index] = True self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
self.can_use_mem_size += len(free_index)
def clear(self): def clear(self):
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32) # The padded slot 0 is used for writing dummy outputs from padded tokens.
self.free_slots = np.arange(1, self.size + 1)
self.mem_state.fill_(True)
self.can_use_mem_size = self.size
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state[0] = False
@abstractmethod @abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor: def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
head_num: int, head_num: int,
head_dim: int, head_dim: int,
layer_num: int, layer_num: int,
device: str,
): ):
super().__init__(size, dtype) super().__init__(size, dtype)
# [size, head_num, head_dim] for each layer # [size, head_num, head_dim] for each layer
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.k_buffer = [ self.k_buffer = [
torch.empty( torch.empty(
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" (size + 1, head_num, head_dim),
dtype=self.store_dtype,
device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
self.v_buffer = [ self.v_buffer = [
torch.empty( torch.empty(
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda" (size + 1, head_num, head_dim),
dtype=self.store_dtype,
device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
kv_lora_rank: int, kv_lora_rank: int,
qk_rope_head_dim: int, qk_rope_head_dim: int,
layer_num: int, layer_num: int,
device: str,
): ):
super().__init__(size, dtype) super().__init__(size, dtype)
self.kv_lora_rank = kv_lora_rank self.kv_lora_rank = kv_lora_rank
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self.kv_buffer = [ self.kv_buffer = [
torch.empty( torch.empty(
(size + 1, 1, kv_lora_rank + qk_rope_head_dim), (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
dtype=self.store_dtype, dtype=self.store_dtype,
device="cuda", device=device,
) )
for _ in range(layer_num) for _ in range(layer_num)
] ]

View File

@@ -409,8 +409,11 @@ class ModelRunner:
4096, 4096,
) )
device = "cuda"
self.req_to_token_pool = ReqToTokenPool( self.req_to_token_pool = ReqToTokenPool(
max_num_reqs + 1, self.model_config.context_len + 4, device="cuda" max_num_reqs + 1,
self.model_config.context_len + 4,
device=device,
) )
if ( if (
self.model_config.attention_arch == AttentionArch.MLA self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@ class ModelRunner:
kv_lora_rank=self.model_config.kv_lora_rank, kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim, qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.model_config.num_hidden_layers, layer_num=self.model_config.num_hidden_layers,
device=device,
) )
else: else:
self.token_to_kv_pool = MHATokenToKVPool( self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@ class ModelRunner:
head_num=self.model_config.get_num_kv_heads(self.tp_size), head_num=self.model_config.get_num_kv_heads(self.tp_size),
head_dim=self.model_config.head_dim, head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers, layer_num=self.model_config.num_hidden_layers,
device=device,
) )
logger.info( logger.info(
f"Memory pool end. " f"Memory pool end. "