[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
Ke Bao
2024-08-26 08:38:11 +08:00
committed by GitHub
parent 61bb223e0f
commit 2c615d120f
5 changed files with 116 additions and 16 deletions

View File

@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
    """Write this layer's new key/value entries into the KV cache pool.

    The store is delegated entirely to ``token_to_kv_pool.set_kv_buffer``
    instead of fetching the per-layer buffers and index-assigning into them
    here, so the pool is the single place that decides how entries are
    written.

    Args:
        cache_k: Key entries to store for this layer.
        cache_v: Value entries to store for this layer.
        input_metadata: Supplies ``token_to_kv_pool`` (the destination pool)
            and ``out_cache_loc`` (the cache slots to write to).
    """
    # NOTE(review): the former direct ``k_cache[loc] = cache_k`` /
    # ``v_cache[loc] = cache_v`` writes were dropped as redundant with
    # ``set_kv_buffer``. This assumes the pool performs any dtype handling
    # needed for a reduced-precision (e.g. fp8 e5m2) cache -- confirm
    # against the pool implementation.
    input_metadata.token_to_kv_pool.set_kv_buffer(
        self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
    )