Decouple kv (#679)
@@ -99,7 +99,7 @@ class RadixAttention(nn.Module):
         else:
             o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
                 q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+                input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                 causal=False,
                 sm_scale=self.scaling,
                 logits_soft_cap=self.logit_cap,
@@ -119,7 +119,7 @@ class RadixAttention(nn.Module):
         o = input_metadata.flashinfer_decode_wrapper.forward(
             q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-            input_metadata.token_to_kv_pool.kv_data[self.layer_id],
+            input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
             sm_scale=self.scaling,
             logits_soft_cap=self.logit_cap,
         )
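Both call sites change the same way: instead of the packed per-layer kv_data tensor, they now pass the (k_buffer, v_buffer) pair returned by get_kv_buffer (added further down). flashinfer's paged wrappers accept separate K/V cache tensors as well as one packed tensor, which is presumably what the 0.1.1 version floor at the end of this diff guarantees. A small sketch of how the two layouts relate, with made-up sizes:

import torch

size, head_num, head_dim = 8, 2, 4

# Old layout: K and V interleaved in one tensor, [size, 2 (k/v), head_num, head_dim].
kv_data = torch.randn(size, 2, head_num, head_dim)

# New layout: two independent buffers of [size, head_num, head_dim] each.
k_buffer = kv_data[:, 0].clone()
v_buffer = kv_data[:, 1].clone()

# Slicing the packed tensor yields strided (non-contiguous) views;
# the decoupled buffers are contiguous and can be handed out separately.
assert torch.equal(kv_data[:, 0], k_buffer)
assert not kv_data[:, 0].is_contiguous()
assert k_buffer.is_contiguous()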
@@ -136,33 +136,7 @@ class RadixAttention(nn.Module):
         return self.decode_forward(q, k, v, input_metadata)
 
     def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
-        kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
-        _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
-
-
-try:
-
-    @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
-
-    @_store_kv_cache.register_fake
-    def _(k, v, kv_cache, cache_loc):
-        pass
-
-except:
-
-    def _store_kv_cache(
-        k: torch.Tensor,
-        v: torch.Tensor,
-        kv_cache: torch.Tensor,
-        cache_loc: torch.Tensor,
-    ) -> None:
-        kv_cache[cache_loc, 0] = k
-        kv_cache[cache_loc, 1] = v
+        k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
+        v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
+        k_cache[input_metadata.out_cache_loc] = cache_k
+        v_cache[input_metadata.out_cache_loc] = cache_v
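The deleted try/except routed the packed-buffer write through torch.library.custom_op, with a fallback for PyTorch versions that lack it; with one contiguous buffer per layer for K and for V, a plain index assignment does the job and the module-level helper disappears. A standalone sketch of the new store path, shapes invented for illustration:

import torch

k_cache = torch.zeros(16, 2, 4)       # [slots, head_num, head_dim]
v_cache = torch.zeros(16, 2, 4)
out_cache_loc = torch.tensor([3, 7])  # slots assigned to two new tokens
cache_k = torch.ones(2, 2, 4)         # this step's keys
cache_v = torch.ones(2, 2, 4)         # this step's values

# Advanced indexing scatters each token's K and V into its slot in place.
k_cache[out_cache_loc] = cache_k
v_cache[out_cache_loc] = cache_v
assert k_cache[3].eq(1).all() and v_cache[7].eq(1).all()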
@@ -57,9 +57,13 @@ class TokenToKVPool:
         # We also add one slot. This slot is used for writing dummy output from padded tokens.
         self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
 
-        # [size, key/value, head_num, head_dim] for each layer
-        self.kv_data = [
-            torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
             for _ in range(layer_num)
         ]
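The split changes the tensor count, not the capacity: per layer, the packed buffer held (size + 1) × 2 × head_num × head_dim elements, which is exactly what the two new buffers hold between them. A quick check with illustrative sizes:

import torch

size, head_num, head_dim, dtype = 1024, 8, 128, torch.float16
packed = torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype)
k = torch.empty((size + 1, head_num, head_dim), dtype=dtype)
v = torch.empty((size + 1, head_num, head_dim), dtype=dtype)
assert packed.numel() == k.numel() + v.numel()  # same element count, split in two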
@@ -71,10 +75,13 @@ class TokenToKVPool:
         self.clear()
 
     def get_key_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 0]
+        return self.k_buffer[layer_id]
 
     def get_value_buffer(self, layer_id: int):
-        return self.kv_data[layer_id][:, 1]
+        return self.v_buffer[layer_id]
 
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)
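These three getters are the whole accessor surface callers see after this change. A hypothetical stub, modeling only the buffers and getters, showing how they relate:

import torch

class PoolStub:
    # Minimal stand-in for TokenToKVPool's accessors, for illustration only.
    def __init__(self, layer_num=2, size=16, head_num=4, head_dim=8):
        self.k_buffer = [
            torch.empty((size + 1, head_num, head_dim)) for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size + 1, head_num, head_dim)) for _ in range(layer_num)
        ]

    def get_key_buffer(self, layer_id: int):
        return self.k_buffer[layer_id]

    def get_value_buffer(self, layer_id: int):
        return self.v_buffer[layer_id]

    def get_kv_buffer(self, layer_id: int):
        return self.k_buffer[layer_id], self.v_buffer[layer_id]

pool = PoolStub()
k, v = pool.get_kv_buffer(0)
# The tuple aliases the same storage the individual getters return.
assert k.data_ptr() == pool.get_key_buffer(0).data_ptr()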
@@ -182,7 +182,7 @@ def launch_server(
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.0",
+            "0.1.1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
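The floor moves from 0.1.0 to 0.1.1, presumably the first flashinfer release whose paged attention wrappers take the separate K/V cache pair used above. For context, a rough sketch of what a helper like assert_pkg_version has to do (the real one lives in sglang's utils; packaging is assumed to be installed):

from importlib.metadata import version

from packaging.version import parse

def assert_pkg_version(pkg: str, min_version: str, message: str) -> None:
    # Compare the installed version against the floor and fail loudly if older.
    if parse(version(pkg)) < parse(min_version):
        raise RuntimeError(f"{pkg} must be >= {min_version}. {message}")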