fix 3fs zerocopy (#9938)

Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
This commit is contained in:
pansicheng
2025-09-05 04:24:12 +08:00
committed by GitHub
parent b32ab0705e
commit d07304870b
3 changed files with 49 additions and 53 deletions

View File

@@ -500,20 +500,23 @@ class MHATokenToKVPoolHost(HostKVCache):
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices):
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert len(keys) == (len(indices) // self.page_size)
assert indices is None or (len(keys) == (len(indices) // self.page_size))
key_list = []
buf_list = []
for key, i in zip(keys, range(0, len(indices), self.page_size)):
for i in range(len(keys)):
key = keys[i]
key_list.append(f"{key}-k")
buf_list.append(self.k_buffer[i : i + self.page_size])
key_list.append(f"{key}-v")
buf_list.append(self.v_buffer[i : i + self.page_size])
if indices is not None:
index = indices[i * self.page_size]
buf_list.append(self.k_buffer[index : index + self.page_size])
buf_list.append(self.v_buffer[index : index + self.page_size])
return key_list, buf_list
return key_list, buf_list, 2
class MLATokenToKVPoolHost(HostKVCache):
@@ -728,13 +731,15 @@ class MLATokenToKVPoolHost(HostKVCache):
element_size_list = [element_size] * len(key_list)
return key_list, ptr_list, element_size_list
def get_buffer_with_hash(self, keys, indices):
def get_buffer_with_hash(self, keys, indices=None):
assert self.layout == "page_first"
assert len(keys) == (len(indices) // self.page_size)
assert indices is None or (len(keys) == (len(indices) // self.page_size))
buf_list = []
for i in range(0, len(indices), self.page_size):
buf_list.append(self.kv_buffer[i : i + self.page_size])
if indices is not None:
for i in range(len(keys)):
index = indices[i * self.page_size]
buf_list.append(self.kv_buffer[index : index + self.page_size])
return keys, buf_list
return keys, buf_list, 1

View File

@@ -415,22 +415,12 @@ class HiCacheHF3FS(HiCacheStorage):
return result[0] if result else False
def batch_exists(self, keys: List[str]) -> int:
if self.is_page_first_layout and not self.is_mla_model:
query_keys = []
# Compatible with page_first layout's key format, Refer to memory_pool_host.py#get_buffer_with_hash
for key in keys:
query_keys.append(f"{key}-k")
query_keys.append(f"{key}-v")
key_multiplier = 2
else:
query_keys = keys
key_multiplier = 1
results = self.metadata_client.exists(self.rank, keys)
for i in range(len(keys)):
if not results[i]:
return i
exist_result = self.metadata_client.exists(self.rank, query_keys)
for i in range(len(query_keys)):
if not exist_result[i]:
return i // key_multiplier
return len(query_keys) // key_multiplier
return len(keys)
def clear(self) -> bool:
try: