refactor: Extract repeated member variables in KVCache subclasses to base class. (#6323)
This commit is contained in:
@@ -94,6 +94,33 @@ class ReqToTokenPool:
|
||||
|
||||
|
||||
class KVCache(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
page_size: int,
|
||||
dtype: torch.dtype,
|
||||
layer_num: int,
|
||||
device: str,
|
||||
enable_memory_saver: bool,
|
||||
start_layer: Optional[int] = None,
|
||||
end_layer: Optional[int] = None,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.layer_num = layer_num
|
||||
self.start_layer = start_layer or 0
|
||||
self.end_layer = end_layer or layer_num - 1
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||
@@ -217,25 +244,20 @@ class MHATokenToKVPool(KVCache):
|
||||
start_layer: Optional[int] = None,
|
||||
end_layer: Optional[int] = None,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
super().__init__(
|
||||
size,
|
||||
page_size,
|
||||
dtype,
|
||||
layer_num,
|
||||
device,
|
||||
enable_memory_saver,
|
||||
start_layer,
|
||||
end_layer,
|
||||
)
|
||||
|
||||
self.head_num = head_num
|
||||
self.head_dim = head_dim
|
||||
self.layer_num = layer_num
|
||||
self._create_buffers()
|
||||
self.start_layer = start_layer or 0
|
||||
self.end_layer = end_layer or layer_num - 1
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
self.capture_mode = False
|
||||
@@ -493,26 +515,21 @@ class MLATokenToKVPool(KVCache):
|
||||
start_layer: Optional[int] = None,
|
||||
end_layer: Optional[int] = None,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.layer_num = layer_num
|
||||
self.start_layer = start_layer or 0
|
||||
self.end_layer = end_layer or layer_num - 1
|
||||
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
super().__init__(
|
||||
size,
|
||||
page_size,
|
||||
dtype,
|
||||
layer_num,
|
||||
device,
|
||||
enable_memory_saver,
|
||||
start_layer,
|
||||
end_layer,
|
||||
)
|
||||
|
||||
with memory_saver_adapter.region():
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
|
||||
with self.memory_saver_adapter.region():
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.zeros(
|
||||
@@ -636,20 +653,18 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
start_layer: Optional[int] = None,
|
||||
end_layer: Optional[int] = None,
|
||||
):
|
||||
self.size = size
|
||||
self.page_size = page_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||
enable=enable_memory_saver
|
||||
super().__init__(
|
||||
size,
|
||||
page_size,
|
||||
dtype,
|
||||
layer_num,
|
||||
device,
|
||||
enable_memory_saver,
|
||||
start_layer,
|
||||
end_layer,
|
||||
)
|
||||
|
||||
with memory_saver_adapter.region():
|
||||
with self.memory_saver_adapter.region():
|
||||
# [size, head_num, head_dim] for each layer
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
@@ -672,9 +687,6 @@ class DoubleSparseTokenToKVPool(KVCache):
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
self.start_layer = start_layer or 0
|
||||
self.end_layer = end_layer or layer_num - 1
|
||||
|
||||
def get_key_buffer(self, layer_id: int):
|
||||
return self.k_buffer[layer_id - self.start_layer]
|
||||
|
||||
@@ -742,7 +754,7 @@ class HostKVCache(abc.ABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
device_pool: KVCache,
|
||||
host_to_device_ratio: float,
|
||||
host_size: int,
|
||||
pin_memory: bool,
|
||||
@@ -914,6 +926,8 @@ class HostKVCache(abc.ABC):
|
||||
|
||||
|
||||
class MHATokenToKVPoolHost(HostKVCache):
|
||||
device_pool: MHATokenToKVPool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MHATokenToKVPool,
|
||||
@@ -997,6 +1011,8 @@ class MHATokenToKVPoolHost(HostKVCache):
|
||||
|
||||
|
||||
class MLATokenToKVPoolHost(HostKVCache):
|
||||
device_pool: MLATokenToKVPool
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
device_pool: MLATokenToKVPool,
|
||||
|
||||
Reference in New Issue
Block a user