Fix memory pool index error (#616)

This commit is contained in:
Ying Sheng
2024-07-13 16:45:11 -07:00
committed by GitHub
parent 0feca02dd9
commit 5949b1ca0e
4 changed files with 9 additions and 11 deletions

View File

@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
if input_metadata.out_cache_loc is not None:
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer[input_metadata.out_cache_loc] = cache_v
else:
raise RuntimeError()
value_buffer[input_metadata.out_cache_loc] = cache_v