[Feature] Support fp8 e5m2 kv cache with flashinfer (#1204)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -203,7 +203,6 @@ class RadixAttention(nn.Module):
|
||||
return self.decode_forward(q, k, v, input_metadata)
|
||||
|
||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
    """Store this layer's new key/value entries in the token-to-KV pool.

    The write is delegated entirely to ``token_to_kv_pool.set_kv_buffer``
    rather than indexing the buffers returned by ``get_key_buffer`` /
    ``get_value_buffer`` directly, so each entry is written exactly once
    and always through the pool's own setter.

    Args:
        cache_k: Key entries to store at ``input_metadata.out_cache_loc``.
        cache_v: Value entries to store at ``input_metadata.out_cache_loc``.
        input_metadata: Supplies ``token_to_kv_pool`` (the destination pool)
            and ``out_cache_loc`` (the destination slots).
    """
    # NOTE(review): presumably the pool's setter is where any storage-dtype
    # handling (e.g. an fp8 e5m2 KV cache) lives -- confirm against the pool
    # implementation before reintroducing direct buffer writes here.
    input_metadata.token_to_kv_pool.set_kv_buffer(
        self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v
    )
|
||||
|
||||
Reference in New Issue
Block a user