Decouple kv (#679)
This commit is contained in:
@@ -99,7 +99,7 @@ class RadixAttention(nn.Module):
|
|||||||
else:
|
else:
|
||||||
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||||
causal=False,
|
causal=False,
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
@@ -119,7 +119,7 @@ class RadixAttention(nn.Module):
|
|||||||
|
|
||||||
o = input_metadata.flashinfer_decode_wrapper.forward(
|
o = input_metadata.flashinfer_decode_wrapper.forward(
|
||||||
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
|
||||||
input_metadata.token_to_kv_pool.kv_data[self.layer_id],
|
input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
|
||||||
sm_scale=self.scaling,
|
sm_scale=self.scaling,
|
||||||
logits_soft_cap=self.logit_cap,
|
logits_soft_cap=self.logit_cap,
|
||||||
)
|
)
|
||||||
@@ -136,33 +136,7 @@ class RadixAttention(nn.Module):
|
|||||||
return self.decode_forward(q, k, v, input_metadata)
|
return self.decode_forward(q, k, v, input_metadata)
|
||||||
|
|
||||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||||
kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id]
|
k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
|
||||||
_store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc)
|
v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
|
||||||
|
k_cache[input_metadata.out_cache_loc] = cache_k
|
||||||
|
v_cache[input_metadata.out_cache_loc] = cache_v
|
||||||
try:
|
|
||||||
|
|
||||||
@torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"})
|
|
||||||
def _store_kv_cache(
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
cache_loc: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
kv_cache[cache_loc, 0] = k
|
|
||||||
kv_cache[cache_loc, 1] = v
|
|
||||||
|
|
||||||
@_store_kv_cache.register_fake
|
|
||||||
def _(k, v, kv_cache, cache_loc):
|
|
||||||
pass
|
|
||||||
|
|
||||||
except:
|
|
||||||
|
|
||||||
def _store_kv_cache(
|
|
||||||
k: torch.Tensor,
|
|
||||||
v: torch.Tensor,
|
|
||||||
kv_cache: torch.Tensor,
|
|
||||||
cache_loc: torch.Tensor,
|
|
||||||
) -> None:
|
|
||||||
kv_cache[cache_loc, 0] = k
|
|
||||||
kv_cache[cache_loc, 1] = v
|
|
||||||
|
|||||||
@@ -57,9 +57,13 @@ class TokenToKVPool:
|
|||||||
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
# We also add one slot. This slot is used for writing dummy output from padded tokens.
|
||||||
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
|
||||||
|
|
||||||
# [size, key/value, head_num, head_dim] for each layer
|
# [size, head_num, head_dim] for each layer
|
||||||
self.kv_data = [
|
self.k_buffer = [
|
||||||
torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
|
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||||
|
for _ in range(layer_num)
|
||||||
|
]
|
||||||
|
self.v_buffer = [
|
||||||
|
torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda")
|
||||||
for _ in range(layer_num)
|
for _ in range(layer_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -71,10 +75,13 @@ class TokenToKVPool:
|
|||||||
self.clear()
|
self.clear()
|
||||||
|
|
||||||
def get_key_buffer(self, layer_id: int):
|
def get_key_buffer(self, layer_id: int):
|
||||||
return self.kv_data[layer_id][:, 0]
|
return self.k_buffer[layer_id]
|
||||||
|
|
||||||
def get_value_buffer(self, layer_id: int):
|
def get_value_buffer(self, layer_id: int):
|
||||||
return self.kv_data[layer_id][:, 1]
|
return self.v_buffer[layer_id]
|
||||||
|
|
||||||
|
def get_kv_buffer(self, layer_id: int):
|
||||||
|
return self.k_buffer[layer_id], self.v_buffer[layer_id]
|
||||||
|
|
||||||
def available_size(self):
|
def available_size(self):
|
||||||
return self.can_use_mem_size + len(self.prefetch_buffer)
|
return self.can_use_mem_size + len(self.prefetch_buffer)
|
||||||
|
|||||||
@@ -182,7 +182,7 @@ def launch_server(
|
|||||||
if not server_args.disable_flashinfer:
|
if not server_args.disable_flashinfer:
|
||||||
assert_pkg_version(
|
assert_pkg_version(
|
||||||
"flashinfer",
|
"flashinfer",
|
||||||
"0.1.0",
|
"0.1.1",
|
||||||
"Please uninstall the old version and "
|
"Please uninstall the old version and "
|
||||||
"reinstall the latest version by following the instructions "
|
"reinstall the latest version by following the instructions "
|
||||||
"at https://docs.flashinfer.ai/installation.html.",
|
"at https://docs.flashinfer.ai/installation.html.",
|
||||||
|
|||||||
Reference in New Issue
Block a user