diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 7ba163959..e0588c407 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -834,6 +834,8 @@ class Scheduler:
 
         next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
+        self.token_to_kv_pool.free_group_begin()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():
@@ -860,6 +862,8 @@ class Scheduler:
 
         self.stream_output(batch.reqs)
 
+        self.token_to_kv_pool.free_group_end()
+
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
             self.print_decode_stats()
diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py
index 1e90c4f67..c8afc1572 100644
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 from typing import List, Tuple, Union
 
-import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
             self.store_dtype = dtype
 
         self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
         self.clear()
 
     def available_size(self):
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
 
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
-        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
+        return select_index.to(self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = np.arange(1, self.size + 1)
+        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.is_not_in_free_group = True
+        self.free_group = []
 
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
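
The scheduler now brackets the per-request frees of a decode step with `free_group_begin()`/`free_group_end()`: inside the bracket the pool only buffers the freed indices, then returns them to `free_slots` with a single `torch.concat` at the end. A minimal, self-contained sketch of that pattern (a toy pool with illustrative names, not the sglang class; device handling and dtype details omitted):

```python
import torch


class FreeGroupPool:
    """Toy slot pool demonstrating grouped (deferred) frees."""

    def __init__(self, size: int):
        # Slot 0 is reserved for dummy writes from padded tokens.
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            # Immediate path: one concat (and one new tensor) per call.
            self.free_slots = torch.concat((self.free_slots, free_index))
        else:
            # Grouped path: just buffer; the concat happens once at group end.
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))


pool = FreeGroupPool(size=16)
a, b = pool.alloc(4), pool.alloc(4)
pool.free_group_begin()
pool.free(a)  # buffered, not yet back in free_slots
pool.free(b)
pool.free_group_end()  # both batches returned with a single concat
assert len(pool.free_slots) == 16
```

The grouping matters because each immediate `free()` rebuilds `free_slots` via `torch.concat`; with N requests finishing in one decode step, the grouped path does one concatenation instead of N.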