diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index bf2ca9ba2..45e5e02cb 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -99,7 +99,7 @@ class RadixAttention(nn.Module): else: o2, s2 = input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.kv_data[self.layer_id], + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), causal=False, sm_scale=self.scaling, logits_soft_cap=self.logit_cap, @@ -119,7 +119,7 @@ class RadixAttention(nn.Module): o = input_metadata.flashinfer_decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.kv_data[self.layer_id], + input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), sm_scale=self.scaling, logits_soft_cap=self.logit_cap, ) @@ -136,33 +136,7 @@ class RadixAttention(nn.Module): return self.decode_forward(q, k, v, input_metadata) def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata): - kv_cache = input_metadata.token_to_kv_pool.kv_data[self.layer_id] - _store_kv_cache(cache_k, cache_v, kv_cache, input_metadata.out_cache_loc) - - -try: - - @torch.library.custom_op("mylib::store_kv_cache", mutates_args={"kv_cache"}) - def _store_kv_cache( - k: torch.Tensor, - v: torch.Tensor, - kv_cache: torch.Tensor, - cache_loc: torch.Tensor, - ) -> None: - kv_cache[cache_loc, 0] = k - kv_cache[cache_loc, 1] = v - - @_store_kv_cache.register_fake - def _(k, v, kv_cache, cache_loc): - pass - -except: - - def _store_kv_cache( - k: torch.Tensor, - v: torch.Tensor, - kv_cache: torch.Tensor, - cache_loc: torch.Tensor, - ) -> None: - kv_cache[cache_loc, 0] = k - kv_cache[cache_loc, 1] = v + k_cache = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id) + v_cache = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id) + k_cache[input_metadata.out_cache_loc] = cache_k + v_cache[input_metadata.out_cache_loc] = cache_v diff --git a/python/sglang/srt/memory_pool.py b/python/sglang/srt/memory_pool.py index 573771334..a6335797c 100644 --- a/python/sglang/srt/memory_pool.py +++ b/python/sglang/srt/memory_pool.py @@ -57,9 +57,13 @@ class TokenToKVPool: # We also add one slot. This slot is used for writing dummy output from padded tokens. self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda") - # [size, key/value, head_num, head_dim] for each layer - self.kv_data = [ - torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda") + # [size, head_num, head_dim] for each layer + self.k_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") + for _ in range(layer_num) + ] + self.v_buffer = [ + torch.empty((size + 1, head_num, head_dim), dtype=dtype, device="cuda") for _ in range(layer_num) ] @@ -71,10 +75,13 @@ class TokenToKVPool: self.clear() def get_key_buffer(self, layer_id: int): - return self.kv_data[layer_id][:, 0] + return self.k_buffer[layer_id] def get_value_buffer(self, layer_id: int): - return self.kv_data[layer_id][:, 1] + return self.v_buffer[layer_id] + + def get_kv_buffer(self, layer_id: int): + return self.k_buffer[layer_id], self.v_buffer[layer_id] def available_size(self): return self.can_use_mem_size + len(self.prefetch_buffer) diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index ac62f89ae..c65186bf5 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -182,7 +182,7 @@ def launch_server( if not server_args.disable_flashinfer: assert_pkg_version( "flashinfer", - "0.1.0", + "0.1.1", "Please uninstall the old version and " "reinstall the latest version by following the instructions " "at https://docs.flashinfer.ai/installation.html.",