diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 7e08007ed..b6dd8dcdd 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -34,7 +34,7 @@ import triton import triton.language as tl from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda +from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2 logger = logging.getLogger(__name__) @@ -182,6 +182,9 @@ class TokenToKVPoolAllocator: def available_size(self): return len(self.free_slots) + def debug_print(self) -> str: + return "" + def get_kvcache(self): return self._kvcache @@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache): # layer_num x [seq_len, head_num, head_dim] # layer_num x [page_num, page_size, head_num, head_dim] kv_data_ptrs = [ - self.get_key_buffer(i).data_ptr() for i in range(self.layer_num) - ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)] + self.get_key_buffer(i).data_ptr() + for i in range(self.start_layer, self.start_layer + self.layer_num) + ] + [ + self.get_value_buffer(i).data_ptr() + for i in range(self.start_layer, self.start_layer + self.layer_num) + ] kv_data_lens = [ - self.get_key_buffer(i).nbytes for i in range(self.layer_num) - ] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)] + self.get_key_buffer(i).nbytes + for i in range(self.start_layer, self.start_layer + self.layer_num) + ] + [ + self.get_value_buffer(i).nbytes + for i in range(self.start_layer, self.start_layer + self.layer_num) + ] kv_item_lens = [ self.get_key_buffer(i)[0].nbytes * self.page_size - for i in range(self.layer_num) + for i in range(self.start_layer, self.start_layer + self.layer_num) ] + [ self.get_value_buffer(i)[0].nbytes * self.page_size - for i in range(self.layer_num) + for i in range(self.start_layer, self.start_layer + self.layer_num) ] return kv_data_ptrs, kv_data_lens, kv_item_lens @@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache): self.v_buffer[layer_id - self.start_layer][loc] = cache_v -@torch.compile -def fused_downcast( - cache_k: torch.Tensor, - cache_v: torch.Tensor, - k_scale: torch.Tensor, - v_scale: torch.Tensor, - dtype: torch.dtype, - store_dtype: torch.dtype, - max_fp8: float, - min_fp8: float, -): - cache_k = cache_k / k_scale - cache_k = torch.clamp(cache_k, min_fp8, max_fp8) - cache_v = cache_v / v_scale - cache_v = torch.clamp(cache_v, min_fp8, max_fp8) - cache_k = cache_k.to(dtype) - cache_v = cache_v.to(dtype) - cache_k = cache_k.view(store_dtype) - cache_v = cache_v.view(store_dtype) - return cache_k, cache_v - - -# This compiled version is slower in the unit test -# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size -@torch.compile(dynamic=True, backend=get_compiler_backend()) -def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype): - dst_1[loc] = src_1.to(dtype).view(store_dtype) - dst_2[loc] = src_2.to(dtype).view(store_dtype) - - @triton.jit def set_mla_kv_buffer_kernel( kv_buffer_ptr,