Move mem_state update into debug mode (#4525)
This commit is contained in:
@@ -580,13 +580,20 @@ class MemoryStateInt(IntEnum):
|
||||
BACKUP = 4
|
||||
|
||||
|
||||
def synchronized(debug_only=False):
    """Decorator factory that serializes method calls on ``self.lock``.

    Args:
        debug_only: if True, the wrapped method only executes (under the
            lock) when ``self.debug`` is set; otherwise the call is
            skipped and ``True`` is returned, so debug-only state checks
            become no-ops in production.

    Returns:
        A decorator producing a wrapper with the same signature as the
        wrapped method (``functools.wraps`` preserves its metadata).
    """

    def _decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            if (not debug_only) or self.debug:
                # Serialize on the instance lock (an RLock, so nested
                # synchronized calls on the same thread do not deadlock).
                with self.lock:
                    return func(self, *args, **kwargs)
            else:
                # Debug-only method outside debug mode: skip the body.
                return True

        return wrapper

    return _decorator
|
||||
|
||||
|
||||
class HostKVCache(abc.ABC):
|
||||
@@ -631,13 +638,9 @@ class HostKVCache(abc.ABC):
|
||||
|
||||
self.kv_buffer = self.init_kv_buffer()
|
||||
|
||||
# Initialize memory states and tracking structures.
|
||||
self.mem_state = torch.zeros(
|
||||
(self.size,), dtype=torch.uint8, device=self.device
|
||||
)
|
||||
|
||||
# A lock for synchronized operations on memory allocation and state transitions.
|
||||
self.lock = threading.RLock()
|
||||
self.debug = logger.isEnabledFor(logging.DEBUG)
|
||||
self.clear()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -664,13 +667,38 @@ class HostKVCache(abc.ABC):
|
||||
def assign_flat_data(self, indices, flat_data):
    """Write ``flat_data`` into the host buffer at ``indices``.

    Abstract hook: concrete host KV-cache pools must override this.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError()
|
||||
|
||||
@synchronized()
def clear(self):
    """Reset the pool so every slot is free and (debug) state is IDLE."""
    # Initialize memory states and tracking structures.
    # Re-creating the tensor (rather than fill_(0)) also restores the
    # expected dtype/device if the buffer was ever replaced.
    self.mem_state = torch.zeros(
        (self.size,), dtype=torch.uint8, device=self.device
    )
    self.free_slots = torch.arange(self.size, dtype=torch.int64)
|
||||
|
||||
def available_size(self):
    """Return the number of currently free host slots."""
    # NOTE: deliberately undecorated. With the factory-style
    # ``synchronized``, a bare ``@synchronized`` (no parentheses) would
    # bind the function as ``debug_only`` and replace this method with
    # the inner decorator, so calls would return a function instead of
    # an int. Callers such as ``alloc`` already hold the RLock anyway.
    return len(self.free_slots)
|
||||
@synchronized()
def alloc(self, need_size: int) -> torch.Tensor:
    """Reserve ``need_size`` free slots.

    Returns:
        A 1-D int64 tensor of the reserved slot indices, or ``None``
        when fewer than ``need_size`` slots are free.
    """
    if need_size > self.available_size():
        return None

    select_index = self.free_slots[:need_size]
    self.free_slots = self.free_slots[need_size:]

    if self.debug:
        # mem_state bookkeeping is debug-only (see commit #4525).
        self.mem_state[select_index] = MemoryStateInt.RESERVED

    return select_index
|
||||
@synchronized()
def free(self, indices: torch.Tensor) -> int:
    """Return ``indices`` to the free pool.

    Returns:
        The number of slots freed (``len(indices)``).
    """
    self.free_slots = torch.cat([self.free_slots, indices])
    if self.debug:
        # mem_state bookkeeping is debug-only (see commit #4525).
        self.mem_state[indices] = MemoryStateInt.IDLE
    return len(indices)
|
||||
@synchronized(debug_only=True)
def get_state(self, indices: torch.Tensor) -> MemoryStateInt:
    """Return the common MemoryStateInt of ``indices`` (debug mode only).

    All requested slots must share one state; asserts are acceptable
    here because this method only runs when debug mode is enabled.
    """
    assert len(indices) > 0, "The indices should not be empty"
    states = self.mem_state[indices]
    assert (
        states == states[0]
    ).all(), "The memory slots should have the same state {}".format(states)
    return MemoryStateInt(states[0].item())
||||
|
||||
@synchronized
|
||||
def alloc(self, need_size: int) -> torch.Tensor:
|
||||
if need_size > self.can_use_mem_size:
|
||||
return None
|
||||
|
||||
# todo: de-fragementation
|
||||
select_index = self.free_slots[:need_size]
|
||||
self.free_slots = self.free_slots[need_size:]
|
||||
|
||||
self.mem_state[select_index] = MemoryStateInt.RESERVED
|
||||
self.can_use_mem_size -= need_size
|
||||
|
||||
return select_index
|
||||
|
||||
@synchronized(debug_only=True)
def is_reserved(self, indices: torch.Tensor) -> bool:
    """True iff every slot in ``indices`` is RESERVED (debug mode only)."""
    return self.get_state(indices) == MemoryStateInt.RESERVED
|
||||
@synchronized(debug_only=True)
def is_protected(self, indices: torch.Tensor) -> bool:
    """True iff every slot in ``indices`` is PROTECTED (debug mode only)."""
    return self.get_state(indices) == MemoryStateInt.PROTECTED
|
||||
@synchronized(debug_only=True)
def is_synced(self, indices: torch.Tensor) -> bool:
    """True iff every slot in ``indices`` is SYNCED (debug mode only)."""
    return self.get_state(indices) == MemoryStateInt.SYNCED
|
||||
@synchronized(debug_only=True)
def is_backup(self, indices: torch.Tensor) -> bool:
    """True iff every slot in ``indices`` is BACKUP (debug mode only)."""
    return self.get_state(indices) == MemoryStateInt.BACKUP
|
||||
@synchronized(debug_only=True)
def update_backup(self, indices: torch.Tensor):
    """Transition SYNCED slots to BACKUP (debug mode only).

    Raises:
        ValueError: if any slot in ``indices`` is not SYNCED.
    """
    if not self.is_synced(indices):
        raise ValueError(
            f"The host memory slots should be in SYNCED state before turning into BACKUP. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.BACKUP
|
||||
@synchronized(debug_only=True)
def update_synced(self, indices: torch.Tensor):
    """Mark ``indices`` as SYNCED (debug mode only); no precondition check."""
    self.mem_state[indices] = MemoryStateInt.SYNCED
|
||||
@synchronized(debug_only=True)
def protect_write(self, indices: torch.Tensor):
    """Transition RESERVED slots to PROTECTED for a write (debug mode only).

    Raises:
        ValueError: if any slot in ``indices`` is not RESERVED.
    """
    if not self.is_reserved(indices):
        raise ValueError(
            f"The host memory slots should be RESERVED before write operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.PROTECTED
||||
|
||||
@synchronized(debug_only=True)
def protect_load(self, indices: torch.Tensor):
    """Transition BACKUP slots to PROTECTED for a load (debug mode only).

    Raises:
        ValueError: if any slot in ``indices`` is not BACKUP.
    """
    if not self.is_backup(indices):
        raise ValueError(
            f"The host memory slots should be in BACKUP state before load operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.PROTECTED
||||
|
||||
@synchronized(debug_only=True)
def complete_io(self, indices: torch.Tensor):
    """Transition PROTECTED slots to SYNCED after I/O finishes (debug mode only).

    Raises:
        ValueError: if any slot in ``indices`` is not PROTECTED.
    """
    if not self.is_protected(indices):
        raise ValueError(
            f"The host memory slots should be PROTECTED during I/O operations. "
            f"Current state: {self.get_state(indices)}"
        )
    self.mem_state[indices] = MemoryStateInt.SYNCED
||||
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
@synchronized
|
||||
def free(self, indices: torch.Tensor) -> int:
|
||||
self.mem_state[indices] = MemoryStateInt.IDLE
|
||||
self.free_slots = torch.cat([self.free_slots, indices])
|
||||
self.can_use_mem_size += len(indices)
|
||||
return len(indices)
|
||||
|
||||
|
||||
class MHATokenToKVPoolHost(HostKVCache):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user