Detokenize incrementally when streaming (#653)
This commit is contained in:
@@ -136,7 +136,33 @@ class RadixAttention(nn.Module):
|
||||
return self.decode_forward(q, k, v, input_metadata)
|
||||
|
||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
|
||||
key_buffer[input_metadata.out_cache_loc] = cache_k
|
||||
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
|
||||
value_buffer[input_metadata.out_cache_loc] = cache_v
|
||||
kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
|
||||
_store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
|
||||
def _store_kv_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
) -> None:
|
||||
kv_cache[cache_loc, 0] = k
|
||||
kv_cache[cache_loc, 1] = v
|
||||
|
||||
@_store_kv_cache.register_fake
|
||||
def _(k, v, kv_cache, cache_loc):
|
||||
pass
|
||||
|
||||
except:
|
||||
|
||||
def _store_kv_cache(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
cache_loc: torch.Tensor,
|
||||
) -> None:
|
||||
kv_cache[cache_loc, 0] = k
|
||||
kv_cache[cache_loc, 1] = v
|
||||
|
||||
Reference in New Issue
Block a user