Move status check in the memory pool to CPU (#1557)
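Summary of the change: the diff replaces the GPU-resident `mem_state` boolean tensor (and the prefetch buffer built on top of it) in `BaseTokenToKVPool` with a numpy free list kept on the host, so checking and reserving free KV-cache slots no longer launches GPU kernels or forces device synchronization. It also threads an explicit `device` argument through the pool constructors instead of hard-coding `"cuda"`. A minimal sketch of the cost being removed, assuming a CUDA device is available (sizes are illustrative, not from the commit):

import numpy as np
import torch

size = 1 << 20

# Before: slot availability lived in a GPU bool tensor. Finding free slots
# needs a torch.nonzero scan, which synchronizes with the device because the
# result size is data-dependent.
mem_state = torch.ones((size + 1,), dtype=torch.bool, device="cuda")
select_index = torch.nonzero(mem_state).squeeze(1)[:4096].to(torch.int32)

# After: availability is a host-side numpy array. The status check is a len()
# and the reservation is a slice; the GPU is only touched to upload indices.
free_slots = np.arange(1, size + 1)
if 4096 <= len(free_slots):
    select_index = torch.tensor(free_slots[:4096], dtype=torch.int32, device="cuda")
    free_slots = free_slots[4096:]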
@@ -19,6 +19,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
         else:
             self.store_dtype = dtype
 
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
-
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device="cuda")
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
     @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
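The new bookkeeping is simple enough to read in isolation. Below is a standalone sketch of the same logic under a hypothetical `TokenSlotPool` name (not a class from the repo), showing the alloc/free/clear life cycle; slot 0 is held out of the free list so padded tokens can write their dummy output there:

import numpy as np
import torch


class TokenSlotPool:
    def __init__(self, size: int):
        self.size = size
        self.clear()

    def available_size(self) -> int:
        # A pure host-side check: no GPU kernel, no synchronization.
        return len(self.free_slots)

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Only the chosen indices are moved to the GPU.
        return torch.tensor(select_index, dtype=torch.int32, device="cuda")

    def free(self, free_index: torch.Tensor):
        # Released slots are appended back onto the host-side free list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

    def clear(self):
        # Slot 0 is reserved for dummy writes from padded tokens.
        self.free_slots = np.arange(1, self.size + 1)


pool = TokenSlotPool(size=16)
idx = pool.alloc(4)                  # int32 CUDA tensor with 4 slot indices
assert pool.available_size() == 12
pool.free(idx)
assert pool.available_size() == 16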
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
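The `size + 1` rows and the reserved row 0 connect the two halves of the change: indices handed out by `alloc` start at 1, and padded tokens can scatter into row 0 without corrupting live entries. A hedged sketch of how allocated indices address these buffers (the scatter below is an illustration, not the pool's actual write path):

import torch

size, head_num, head_dim, layer_num = 8, 2, 4, 1
device = "cuda"

# One [size + 1, head_num, head_dim] buffer per layer; row 0 is the padded slot.
k_buffer = [
    torch.empty((size + 1, head_num, head_dim), dtype=torch.float16, device=device)
    for _ in range(layer_num)
]

# Suppose alloc() returned these slot indices for three new tokens.
loc = torch.tensor([1, 2, 3], dtype=torch.int32, device=device)
cache_k = torch.randn((3, head_num, head_dim), dtype=torch.float16, device=device)

# Scatter the new K vectors into layer 0's pool at the allocated rows.
k_buffer[0][loc.to(torch.int64)] = cache_k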
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
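For MLA the pool keeps one joint vector per token: the compressed KV latent of width `kv_lora_rank` concatenated with the rotary key part of width `qk_rope_head_dim`, which is why the head dimension is 1. A small sketch of splitting a stored row back into those two parts (the slicing convention is an assumption for illustration):

import torch

kv_lora_rank, qk_rope_head_dim, size = 512, 64, 8

kv_buffer = torch.empty(
    (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
    dtype=torch.bfloat16,
    device="cuda",
)

row = kv_buffer[3]                   # joint entry for slot 3, shape [1, 576]
c_kv = row[..., :kv_lora_rank]       # compressed KV latent
k_rope = row[..., kv_lora_rank:]     # rotary positional key component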
@@ -409,8 +409,11 @@ class ModelRunner:
             4096,
         )
 
+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
             logger.info(
                 f"Memory pool end. "
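The `ModelRunner` hunks are purely mechanical: a single `device` variable (still `"cuda"` for now) is defined once and passed to every pool constructor, replacing the scattered `device="cuda"` literals. A tiny sketch of the pattern with hypothetical stand-in pools (names and shapes are illustrative, not from the repo):

import torch

def build_pools(device: str):
    # Stand-ins for ReqToTokenPool / MHATokenToKVPool: every allocation takes
    # the same device argument instead of a hard-coded "cuda" literal, so
    # retargeting the runner only requires changing one string.
    req_to_token = torch.zeros((65, 2052), dtype=torch.int32, device=device)
    k_buffer = torch.empty((4097, 8, 128), dtype=torch.float16, device=device)
    return req_to_token, k_buffer

req_to_token, k_buffer = build_pools(device="cuda")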