Move status check in the memory pool to CPU (#1557)
This commit is contained in:
@@ -19,6 +19,7 @@ import logging
|
|||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
|
|||||||
else:
|
else:
|
||||||
self.store_dtype = dtype
|
self.store_dtype = dtype
|
||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
self.free_slots = None
|
||||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
|
||||||
|
|
||||||
# Prefetch buffer
|
|
||||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
|
||||||
self.prefetch_chunk_size = 512
|
|
||||||
|
|
||||||
self.can_use_mem_size = self.size
|
|
||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
return len(self.free_slots)
|
||||||
|
|
||||||
def alloc(self, need_size: int):
|
def alloc(self, need_size: int):
|
||||||
buffer_len = len(self.prefetch_buffer)
|
if need_size > len(self.free_slots):
|
||||||
if need_size <= buffer_len:
|
|
||||||
select_index = self.prefetch_buffer[:need_size]
|
|
||||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
|
||||||
return select_index
|
|
||||||
|
|
||||||
addition_size = need_size - buffer_len
|
|
||||||
alloc_size = max(addition_size, self.prefetch_chunk_size)
|
|
||||||
select_index = (
|
|
||||||
torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
|
|
||||||
)
|
|
||||||
|
|
||||||
if select_index.shape[0] < addition_size:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
self.mem_state[select_index] = False
|
select_index = self.free_slots[:need_size]
|
||||||
self.can_use_mem_size -= len(select_index)
|
self.free_slots = self.free_slots[need_size:]
|
||||||
|
|
||||||
self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
|
return torch.tensor(select_index, dtype=torch.int32, device="cuda")
|
||||||
ret_index = self.prefetch_buffer[:need_size]
|
|
||||||
self.prefetch_buffer = self.prefetch_buffer[need_size:]
|
|
||||||
|
|
||||||
return ret_index
|
|
||||||
|
|
||||||
def free(self, free_index: torch.Tensor):
|
def free(self, free_index: torch.Tensor):
|
||||||
self.mem_state[free_index] = True
|
self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
|
||||||
self.can_use_mem_size += len(free_index)
|
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
|
self.free_slots = np.arange(1, self.size + 1)
|
||||||
self.mem_state.fill_(True)
|
|
||||||
self.can_use_mem_size = self.size
|
|
||||||
|
|
||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
|
||||||
self.mem_state[0] = False
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
head_num: int,
|
head_num: int,
|
||||||
head_dim: int,
|
head_dim: int,
|
||||||
layer_num: int,
|
layer_num: int,
|
||||||
|
device: str,
|
||||||
):
|
):
|
||||||
super().__init__(size, dtype)
|
super().__init__(size, dtype)
|
||||||
|
|
||||||
# [size, head_num, head_dim] for each layer
|
# [size, head_num, head_dim] for each layer
|
||||||
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.k_buffer = [
|
self.k_buffer = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
(size + 1, head_num, head_dim),
|
||||||
|
dtype=self.store_dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
self.v_buffer = [
|
self.v_buffer = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
(size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
|
(size + 1, head_num, head_dim),
|
||||||
|
dtype=self.store_dtype,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
|||||||
kv_lora_rank: int,
|
kv_lora_rank: int,
|
||||||
qk_rope_head_dim: int,
|
qk_rope_head_dim: int,
|
||||||
layer_num: int,
|
layer_num: int,
|
||||||
|
device: str,
|
||||||
):
|
):
|
||||||
super().__init__(size, dtype)
|
super().__init__(size, dtype)
|
||||||
|
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||||
self.kv_buffer = [
|
self.kv_buffer = [
|
||||||
torch.empty(
|
torch.empty(
|
||||||
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
(size + 1, 1, kv_lora_rank + qk_rope_head_dim),
|
||||||
dtype=self.store_dtype,
|
dtype=self.store_dtype,
|
||||||
device="cuda",
|
device=device,
|
||||||
)
|
)
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -409,8 +409,11 @@ class ModelRunner:
|
|||||||
4096,
|
4096,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
device = "cuda"
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
|
max_num_reqs + 1,
|
||||||
|
self.model_config.context_len + 4,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
@@ -422,6 +425,7 @@ class ModelRunner:
|
|||||||
kv_lora_rank=self.model_config.kv_lora_rank,
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
||||||
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.token_to_kv_pool = MHATokenToKVPool(
|
self.token_to_kv_pool = MHATokenToKVPool(
|
||||||
@@ -430,6 +434,7 @@ class ModelRunner:
|
|||||||
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
head_num=self.model_config.get_num_kv_heads(self.tp_size),
|
||||||
head_dim=self.model_config.head_dim,
|
head_dim=self.model_config.head_dim,
|
||||||
layer_num=self.model_config.num_hidden_layers,
|
layer_num=self.model_config.num_hidden_layers,
|
||||||
|
device=device,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
f"Memory pool end. "
|
f"Memory pool end. "
|
||||||
|
|||||||
Reference in New Issue
Block a user