diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 8860ce5ce..7dea057d1 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -19,6 +19,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union

+import numpy as np
 import torch

 logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
         else:
             self.store_dtype = dtype

-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()

     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)

     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None

-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]

-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
-
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device="cuda")

     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))

     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)

     @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)

         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)

         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 79d42f546..f596acbe9 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -409,8 +409,11 @@ class ModelRunner:
             4096,
         )

+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "
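Note on the change: the boolean `mem_state` bitmap and the GPU-side prefetch buffer in `BaseTokenToKVPool` are replaced with a plain numpy free list. `clear()` seeds `free_slots` with indices `1..size` (slot 0 stays reserved for dummy writes from padded tokens), `alloc()` pops indices off the front, and `free()` appends released indices back. The sketch below is a minimal, standalone illustration of that behavior only; `ToyFreeList` and its numpy-only return type are hypothetical placeholders, not part of the patch.

```python
import numpy as np


class ToyFreeList:
    """Illustrative sketch of the free-slot scheme introduced by this diff."""

    def __init__(self, size: int):
        self.size = size
        self.clear()

    def clear(self):
        # Slot 0 is reserved for dummy outputs from padded tokens, so hand out 1..size.
        self.free_slots = np.arange(1, self.size + 1)

    def available_size(self) -> int:
        return len(self.free_slots)

    def alloc(self, need_size: int):
        # Pop `need_size` indices off the front; fail if not enough slots remain.
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: np.ndarray):
        # Released indices simply rejoin the pool; ordering is not preserved.
        self.free_slots = np.concatenate((self.free_slots, free_index))


pool = ToyFreeList(size=8)
a = pool.alloc(3)          # array([1, 2, 3])
pool.free(a[:1])           # slot 1 returns to the pool
assert pool.available_size() == 6
```

Keeping the free list on the CPU removes the `torch.nonzero` scan over the bitmap that the old `alloc()` performed; in the actual patch only the selected indices are materialized as a CUDA tensor when handed back to the caller.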