Support page_first_direct layout in memory_pool_host (#10031)
Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
@@ -532,9 +532,12 @@ class HiCacheController:
|
|||||||
host_indices = host_indices.to(self.device, non_blocking=True)
|
host_indices = host_indices.to(self.device, non_blocking=True)
|
||||||
return host_indices, device_indices
|
return host_indices, device_indices
|
||||||
elif self.io_backend == "direct":
|
elif self.io_backend == "direct":
|
||||||
device_indices = device_indices.cpu()
|
if self.mem_pool_host.layout == "layer_first":
|
||||||
host_indices, idx = host_indices.sort()
|
device_indices = device_indices.cpu()
|
||||||
return host_indices, device_indices.index_select(0, idx)
|
host_indices, idx = host_indices.sort()
|
||||||
|
return host_indices, device_indices.index_select(0, idx)
|
||||||
|
elif self.mem_pool_host.layout == "page_first_direct":
|
||||||
|
return host_indices, device_indices.cpu()
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported io backend")
|
raise ValueError(f"Unsupported io backend")
|
||||||
|
|
||||||
|
|||||||
@@ -16,11 +16,13 @@ _is_xpu = is_xpu()
|
|||||||
if not (_is_npu or _is_xpu):
|
if not (_is_npu or _is_xpu):
|
||||||
from sgl_kernel.kvcacheio import (
|
from sgl_kernel.kvcacheio import (
|
||||||
transfer_kv_all_layer,
|
transfer_kv_all_layer,
|
||||||
|
transfer_kv_all_layer_direct_lf_pf,
|
||||||
transfer_kv_all_layer_lf_pf,
|
transfer_kv_all_layer_lf_pf,
|
||||||
transfer_kv_all_layer_mla,
|
transfer_kv_all_layer_mla,
|
||||||
transfer_kv_all_layer_mla_lf_pf,
|
transfer_kv_all_layer_mla_lf_pf,
|
||||||
transfer_kv_direct,
|
transfer_kv_direct,
|
||||||
transfer_kv_per_layer,
|
transfer_kv_per_layer,
|
||||||
|
transfer_kv_per_layer_direct_pf_lf,
|
||||||
transfer_kv_per_layer_mla,
|
transfer_kv_per_layer_mla,
|
||||||
transfer_kv_per_layer_mla_pf_lf,
|
transfer_kv_per_layer_mla_pf_lf,
|
||||||
transfer_kv_per_layer_pf_lf,
|
transfer_kv_per_layer_pf_lf,
|
||||||
@@ -78,6 +80,7 @@ class HostKVCache(abc.ABC):
|
|||||||
self.size = int(device_pool.size * host_to_device_ratio)
|
self.size = int(device_pool.size * host_to_device_ratio)
|
||||||
# Align the host memory pool size to the page size
|
# Align the host memory pool size to the page size
|
||||||
self.size = self.size - (self.size % self.page_size)
|
self.size = self.size - (self.size % self.page_size)
|
||||||
|
self.page_num = self.size // self.page_size
|
||||||
self.start_layer = device_pool.start_layer
|
self.start_layer = device_pool.start_layer
|
||||||
self.end_layer = device_pool.end_layer
|
self.end_layer = device_pool.end_layer
|
||||||
|
|
||||||
@@ -317,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
|
||||||
elif self.layout == "page_first":
|
elif self.layout == "page_first":
|
||||||
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
|
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
|
||||||
|
elif self.layout == "page_first_direct":
|
||||||
|
dims = (
|
||||||
|
2,
|
||||||
|
self.page_num,
|
||||||
|
self.layer_num,
|
||||||
|
self.page_size,
|
||||||
|
self.head_num,
|
||||||
|
self.head_dim,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
|
||||||
@@ -370,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
elif io_backend == "direct":
|
elif io_backend == "direct":
|
||||||
assert (
|
if self.layout == "layer_first":
|
||||||
self.layout == "layer_first"
|
transfer_kv_direct(
|
||||||
), f"Direct IO backend only supports layer_first layout."
|
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
|
||||||
transfer_kv_direct(
|
dst_layers=[
|
||||||
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
|
device_pool.k_buffer[layer_id],
|
||||||
dst_layers=[
|
device_pool.v_buffer[layer_id],
|
||||||
device_pool.k_buffer[layer_id],
|
],
|
||||||
device_pool.v_buffer[layer_id],
|
src_indices=host_indices,
|
||||||
],
|
dst_indices=device_indices,
|
||||||
src_indices=host_indices,
|
page_size=self.page_size,
|
||||||
dst_indices=device_indices,
|
)
|
||||||
page_size=self.page_size,
|
elif self.layout == "page_first_direct":
|
||||||
)
|
transfer_kv_per_layer_direct_pf_lf(
|
||||||
|
src_ptrs=[self.k_buffer, self.v_buffer],
|
||||||
|
dst_ptrs=[
|
||||||
|
device_pool.k_buffer[layer_id],
|
||||||
|
device_pool.v_buffer[layer_id],
|
||||||
|
],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
layer_id=layer_id,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
@@ -416,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
elif io_backend == "direct":
|
elif io_backend == "direct":
|
||||||
assert (
|
if self.layout == "layer_first":
|
||||||
self.layout == "layer_first"
|
transfer_kv_direct(
|
||||||
), f"Direct IO backend only supports layer_first layout."
|
src_layers=device_pool.k_buffer + device_pool.v_buffer,
|
||||||
transfer_kv_direct(
|
dst_layers=self.k_data_refs + self.v_data_refs,
|
||||||
src_layers=device_pool.k_buffer + device_pool.v_buffer,
|
src_indices=device_indices,
|
||||||
dst_layers=self.k_data_refs + self.v_data_refs,
|
dst_indices=host_indices,
|
||||||
src_indices=device_indices,
|
page_size=self.page_size,
|
||||||
dst_indices=host_indices,
|
)
|
||||||
page_size=self.page_size,
|
elif self.layout == "page_first_direct":
|
||||||
)
|
transfer_kv_all_layer_direct_lf_pf(
|
||||||
|
src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
|
||||||
|
dst_ptrs=[self.k_buffer, self.v_buffer],
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
@@ -580,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
1,
|
1,
|
||||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
)
|
)
|
||||||
|
elif self.layout == "page_first_direct":
|
||||||
|
dims = (
|
||||||
|
self.page_num,
|
||||||
|
self.layer_num,
|
||||||
|
self.page_size,
|
||||||
|
1,
|
||||||
|
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
self.token_stride_size = (
|
self.token_stride_size = (
|
||||||
@@ -619,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
elif io_backend == "direct":
|
elif io_backend == "direct":
|
||||||
assert (
|
if self.layout == "layer_first":
|
||||||
self.layout == "layer_first"
|
transfer_kv_direct(
|
||||||
), f"Direct IO backend only supports layer_first layout."
|
src_layers=[self.kv_buffer[layer_id]],
|
||||||
transfer_kv_direct(
|
dst_layers=[device_pool.kv_buffer[layer_id]],
|
||||||
src_layers=[self.kv_buffer[layer_id]],
|
src_indices=host_indices,
|
||||||
dst_layers=[device_pool.kv_buffer[layer_id]],
|
dst_indices=device_indices,
|
||||||
src_indices=host_indices,
|
page_size=self.page_size,
|
||||||
dst_indices=device_indices,
|
)
|
||||||
page_size=self.page_size,
|
elif self.layout == "page_first_direct":
|
||||||
)
|
transfer_kv_per_layer_direct_pf_lf(
|
||||||
|
src_ptrs=[self.kv_buffer],
|
||||||
|
dst_ptrs=[device_pool.kv_buffer[layer_id]],
|
||||||
|
src_indices=host_indices,
|
||||||
|
dst_indices=device_indices,
|
||||||
|
layer_id=layer_id,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
|
|
||||||
def backup_from_device_all_layer(
|
def backup_from_device_all_layer(
|
||||||
self, device_pool, host_indices, device_indices, io_backend
|
self, device_pool, host_indices, device_indices, io_backend
|
||||||
@@ -656,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported layout: {self.layout}")
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
elif io_backend == "direct":
|
elif io_backend == "direct":
|
||||||
assert (
|
if self.layout == "layer_first":
|
||||||
self.layout == "layer_first"
|
transfer_kv_direct(
|
||||||
), f"Direct IO backend only supports layer_first layout."
|
src_layers=device_pool.kv_buffer,
|
||||||
transfer_kv_direct(
|
dst_layers=self.data_refs,
|
||||||
src_layers=device_pool.kv_buffer,
|
src_indices=device_indices,
|
||||||
dst_layers=self.data_refs,
|
dst_indices=host_indices,
|
||||||
src_indices=device_indices,
|
page_size=self.page_size,
|
||||||
dst_indices=host_indices,
|
)
|
||||||
page_size=self.page_size,
|
elif self.layout == "page_first_direct":
|
||||||
)
|
transfer_kv_all_layer_direct_lf_pf(
|
||||||
|
src_ptrs=device_pool.kv_buffer,
|
||||||
|
dst_ptrs=[self.kv_buffer],
|
||||||
|
src_indices=device_indices,
|
||||||
|
dst_indices=host_indices,
|
||||||
|
page_size=self.page_size,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported layout: {self.layout}")
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
raise ValueError(f"Unsupported IO backend: {io_backend}")
|
||||||
|
|
||||||
|
|||||||
@@ -721,6 +721,13 @@ class ServerArgs:
|
|||||||
self.hicache_io_backend = "kernel"
|
self.hicache_io_backend = "kernel"
|
||||||
self.hicache_mem_layout = "page_first"
|
self.hicache_mem_layout = "page_first"
|
||||||
|
|
||||||
|
if self.hicache_mem_layout == "page_first_direct":
|
||||||
|
if self.hicache_io_backend != "direct":
|
||||||
|
self.hicache_io_backend = "direct"
|
||||||
|
logger.warning(
|
||||||
|
"Page first direct layout only support direct io backend"
|
||||||
|
)
|
||||||
|
|
||||||
# Speculative Decoding
|
# Speculative Decoding
|
||||||
if self.speculative_algorithm == "NEXTN":
|
if self.speculative_algorithm == "NEXTN":
|
||||||
# NEXTN shares the same implementation of EAGLE
|
# NEXTN shares the same implementation of EAGLE
|
||||||
@@ -1779,7 +1786,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--hicache-mem-layout",
|
"--hicache-mem-layout",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["layer_first", "page_first"],
|
choices=["layer_first", "page_first", "page_first_direct"],
|
||||||
default=ServerArgs.hicache_mem_layout,
|
default=ServerArgs.hicache_mem_layout,
|
||||||
help="The layout of host memory pool for hierarchical cache.",
|
help="The layout of host memory pool for hierarchical cache.",
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user