From 4d253057000eaf5a4b9a8cc9932e884c6ecdfca0 Mon Sep 17 00:00:00 2001 From: Zhiqiang Xie Date: Sun, 23 Mar 2025 00:52:27 -0700 Subject: [PATCH] Move mem_state update into debug mode (#4525) --- python/sglang/srt/mem_cache/memory_pool.py | 135 +++++++++++---------- 1 file changed, 71 insertions(+), 64 deletions(-) diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 2b0f72be8..a882f7451 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -580,13 +580,19 @@ class MemoryStateInt(IntEnum): BACKUP = 4 -def synchronized(func): - @wraps(func) - def wrapper(self, *args, **kwargs): - with self.lock: - return func(self, *args, **kwargs) +def synchronized(debug_only=False): + def _decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if (not debug_only) or self.debug: + with self.lock: + return func(self, *args, **kwargs) + else: + return True - return wrapper + return wrapper + + return _decorator class HostKVCache(abc.ABC): @@ -631,13 +637,9 @@ class HostKVCache(abc.ABC): self.kv_buffer = self.init_kv_buffer() - # Initialize memory states and tracking structures. - self.mem_state = torch.zeros( - (self.size,), dtype=torch.uint8, device=self.device - ) - # A lock for synchronized operations on memory allocation and state transitions. self.lock = threading.RLock() + self.debug = logger.isEnabledFor(logging.DEBUG) self.clear() @abc.abstractmethod @@ -664,13 +666,38 @@ class HostKVCache(abc.ABC): def assign_flat_data(self, indices, flat_data): raise NotImplementedError() - @synchronized + @synchronized() def clear(self): - self.mem_state.fill_(0) - self.can_use_mem_size = self.size + # Initialize memory states and tracking structures. 
+ self.mem_state = torch.zeros( + (self.size,), dtype=torch.uint8, device=self.device + ) self.free_slots = torch.arange(self.size, dtype=torch.int64) - @synchronized + def available_size(self): + return len(self.free_slots) + + @synchronized() + def alloc(self, need_size: int) -> torch.Tensor: + if need_size > self.available_size(): + return None + + select_index = self.free_slots[:need_size] + self.free_slots = self.free_slots[need_size:] + + if self.debug: + self.mem_state[select_index] = MemoryStateInt.RESERVED + + return select_index + + @synchronized() + def free(self, indices: torch.Tensor) -> int: + self.free_slots = torch.cat([self.free_slots, indices]) + if self.debug: + self.mem_state[indices] = MemoryStateInt.IDLE + return len(indices) + + @synchronized(debug_only=True) def get_state(self, indices: torch.Tensor) -> MemoryStateInt: assert len(indices) > 0, "The indices should not be empty" states = self.mem_state[indices] @@ -679,82 +707,62 @@ class HostKVCache(abc.ABC): ).all(), "The memory slots should have the same state {}".format(states) return MemoryStateInt(states[0].item()) - @synchronized - def alloc(self, need_size: int) -> torch.Tensor: - if need_size > self.can_use_mem_size: - return None - - # todo: de-fragementation - select_index = self.free_slots[:need_size] - self.free_slots = self.free_slots[need_size:] - - self.mem_state[select_index] = MemoryStateInt.RESERVED - self.can_use_mem_size -= need_size - - return select_index - - @synchronized + @synchronized(debug_only=True) def is_reserved(self, indices: torch.Tensor) -> bool: return self.get_state(indices) == MemoryStateInt.RESERVED - @synchronized + @synchronized(debug_only=True) def is_protected(self, indices: torch.Tensor) -> bool: return self.get_state(indices) == MemoryStateInt.PROTECTED - @synchronized + @synchronized(debug_only=True) def is_synced(self, indices: torch.Tensor) -> bool: return self.get_state(indices) == MemoryStateInt.SYNCED - @synchronized + 
@synchronized(debug_only=True) def is_backup(self, indices: torch.Tensor) -> bool: return self.get_state(indices) == MemoryStateInt.BACKUP - @synchronized + @synchronized(debug_only=True) def update_backup(self, indices: torch.Tensor): - assert self.is_synced(indices), ( - f"The host memory slots should be in SYNCED state before turning into BACKUP. " - f"Current state: {self.get_state(indices)}" - ) + if not self.is_synced(indices): + raise ValueError( + f"The host memory slots should be in SYNCED state before turning into BACKUP. " + f"Current state: {self.get_state(indices)}" + ) self.mem_state[indices] = MemoryStateInt.BACKUP - @synchronized + @synchronized(debug_only=True) def update_synced(self, indices: torch.Tensor): self.mem_state[indices] = MemoryStateInt.SYNCED - @synchronized + @synchronized(debug_only=True) def protect_write(self, indices: torch.Tensor): - assert self.is_reserved(indices), ( - f"The host memory slots should be RESERVED before write operations. " - f"Current state: {self.get_state(indices)}" - ) + if not self.is_reserved(indices): + raise ValueError( + f"The host memory slots should be RESERVED before write operations. " + f"Current state: {self.get_state(indices)}" + ) self.mem_state[indices] = MemoryStateInt.PROTECTED - @synchronized + @synchronized(debug_only=True) def protect_load(self, indices: torch.Tensor): - assert self.is_backup(indices), ( - f"The host memory slots should be in BACKUP state before load operations. " - f"Current state: {self.get_state(indices)}" - ) + if not self.is_backup(indices): + raise ValueError( + f"The host memory slots should be in BACKUP state before load operations. " + f"Current state: {self.get_state(indices)}" + ) self.mem_state[indices] = MemoryStateInt.PROTECTED - @synchronized + @synchronized(debug_only=True) def complete_io(self, indices: torch.Tensor): - assert self.is_protected(indices), ( - f"The host memory slots should be PROTECTED during I/O operations. 
" - f"Current state: {self.get_state(indices)}" - ) + if not self.is_protected(indices): + raise ValueError( + f"The host memory slots should be PROTECTED during I/O operations. " + f"Current state: {self.get_state(indices)}" + ) self.mem_state[indices] = MemoryStateInt.SYNCED - def available_size(self): - return len(self.free_slots) - - @synchronized - def free(self, indices: torch.Tensor) -> int: - self.mem_state[indices] = MemoryStateInt.IDLE - self.free_slots = torch.cat([self.free_slots, indices]) - self.can_use_mem_size += len(indices) - return len(indices) - class MHATokenToKVPoolHost(HostKVCache): def __init__(