Detokenize incrementally when streaming (#653)

This commit is contained in:
Liangsheng Yin
2024-07-18 17:57:40 -07:00
committed by GitHub
parent 21ba3a88a1
commit a9ef49c12c
5 changed files with 101 additions and 33 deletions

View File

@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
return self.decode_forward(q, k, v, input_metadata)
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
key_buffer[input_metadata.out_cache_loc] = cache_k
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
value_buffer[input_metadata.out_cache_loc] = cache_v
kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
_store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
try:
@torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
def _store_kv_cache(
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
cache_loc: torch.Tensor,
) -> None:
kv_cache[cache_loc, 0] = k
kv_cache[cache_loc, 1] = v
@_store_kv_cache.register_fake
def _(k, v, kv_cache, cache_loc):
pass
except:
def _store_kv_cache(
k: torch.Tensor,
v: torch.Tensor,
kv_cache: torch.Tensor,
cache_loc: torch.Tensor,
) -> None:
kv_cache[cache_loc, 0] = k
kv_cache[cache_loc, 1] = v