refactor: Extract repeated member variables in KVCache subclasses to base class. (#6323)
@@ -94,6 +94,33 @@ class ReqToTokenPool:
 
 
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
+
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
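
For reference, a minimal, self-contained sketch of the pattern this hunk introduces: the shared fields move into an abstract __init__, and subclasses keep their own constructors but delegate the common setup. The class names and reduced argument list here are illustrative, not the real sglang API.

import abc
from typing import Optional

import torch


class KVCacheSketch(abc.ABC):
    # Marking __init__ abstract keeps the base class non-instantiable while
    # still letting subclasses reuse the shared setup via super().__init__.
    @abc.abstractmethod
    def __init__(
        self,
        size: int,
        dtype: torch.dtype,
        layer_num: int,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ):
        self.size = size
        self.dtype = dtype
        # Same workaround as the NOTE above: float8 buffers are stored as
        # uint8 because Tensor.index_put is not implemented for float8.
        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
            self.store_dtype = torch.uint8
        else:
            self.store_dtype = dtype
        self.layer_num = layer_num
        self.start_layer = start_layer or 0
        self.end_layer = end_layer or layer_num - 1


class MHASketch(KVCacheSketch):
    def __init__(self, size: int, dtype: torch.dtype, layer_num: int, head_num: int):
        super().__init__(size, dtype, layer_num)  # shared fields set once
        self.head_num = head_num  # subclass-specific field stays local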
@@ -217,25 +244,20 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
         self.capture_mode = False
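
The NOTE that now lives only in the base class refers to a real PyTorch limitation: index assignment is not implemented for the float8 dtypes, so the pool keeps a uint8 buffer of the same bytes. A small illustration of that round trip, assuming a PyTorch build with float8 dtypes (shapes here are arbitrary):

import torch

# Writes go through a uint8 view of the same bytes; reads reinterpret back.
vals = torch.randn(2, 2).to(torch.float8_e5m2)      # values to write
buf = torch.zeros(4, 2, dtype=torch.uint8)          # store_dtype buffer
buf[torch.tensor([0, 2])] = vals.view(torch.uint8)  # index_put works on uint8
restored = buf.view(torch.float8_e5m2)              # same bytes, float8 again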
@@ -493,26 +515,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        self.kv_lora_rank = kv_lora_rank
-        self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
-        with memory_saver_adapter.region():
+        self.kv_lora_rank = kv_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
@@ -636,20 +653,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-        self.size = size
-        self.page_size = page_size
-        self.dtype = dtype
-        self.device = device
-        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
-            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
-            self.store_dtype = torch.uint8
-        else:
-            self.store_dtype = dtype
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -672,9 +687,6 @@ class DoubleSparseTokenToKVPool(KVCache):
             for _ in range(layer_num)
         ]
 
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
 
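
Because the adapter is now created once in KVCache.__init__ and stored on self, the MLA and DoubleSparse pools above switch from a local memory_saver_adapter variable to self.memory_saver_adapter, and the adapter outlives the constructor. A hypothetical stand-in for the interface the diff relies on (a create factory plus a region context manager; the real TorchMemorySaverAdapter is sglang's wrapper, this only sketches its shape):

from contextlib import nullcontext


class MemorySaverAdapterSketch:
    # Factory mirroring the TorchMemorySaverAdapter.create(enable=...) call.
    @staticmethod
    def create(enable: bool) -> "MemorySaverAdapterSketch":
        return MemorySaverAdapterSketch(enable)

    def __init__(self, enable: bool):
        self.enable = enable

    def region(self):
        # The real adapter scopes allocations so they can later be released
        # and restored; a disabled adapter degrades to a no-op context.
        return nullcontext()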
@@ -742,7 +754,7 @@ class HostKVCache(abc.ABC):
 
     def __init__(
         self,
-        device_pool: MHATokenToKVPool,
+        device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
        pin_memory: bool,
@@ -914,6 +926,8 @@ class HostKVCache(abc.ABC):
 
 
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
@@ -997,6 +1011,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 
 
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
    def __init__(
         self,
         device_pool: MLATokenToKVPool,
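
The last three hunks work together: HostKVCache widens its device_pool parameter to the KVCache base type, and each concrete host pool re-declares device_pool with a class-level annotation so type checkers still see the narrow type. A toy sketch of that annotate-to-narrow pattern (all names here are illustrative):

class DevicePool:
    size = 0


class MHADevicePool(DevicePool):
    head_num = 8  # attribute only the MHA pool has


class HostCache:
    device_pool: DevicePool  # base accepts any device pool

    def __init__(self, device_pool: DevicePool):
        self.device_pool = device_pool


class MHAHostCache(HostCache):
    device_pool: MHADevicePool  # narrowed: self.device_pool.head_num type-checks

    def __init__(self, device_pool: MHADevicePool):
        super().__init__(device_pool)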