Lianmin/simplify memory pool (#7202)
This commit is contained in:
@@ -34,7 +34,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda
|
||||
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -182,6 +182,9 @@ class TokenToKVPoolAllocator:
|
||||
def available_size(self):
|
||||
return len(self.free_slots)
|
||||
|
||||
def debug_print(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_kvcache(self):
|
||||
return self._kvcache
|
||||
|
||||
@@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache):
|
||||
# layer_num x [seq_len, head_num, head_dim]
|
||||
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||
kv_data_ptrs = [
|
||||
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
||||
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
||||
self.get_key_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_data_lens = [
|
||||
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
||||
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
||||
self.get_key_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_item_lens = [
|
||||
self.get_key_buffer(i)[0].nbytes * self.page_size
|
||||
for i in range(self.layer_num)
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i)[0].nbytes * self.page_size
|
||||
for i in range(self.layer_num)
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
@@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache):
|
||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||
|
||||
|
||||
@torch.compile
|
||||
def fused_downcast(
|
||||
cache_k: torch.Tensor,
|
||||
cache_v: torch.Tensor,
|
||||
k_scale: torch.Tensor,
|
||||
v_scale: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
store_dtype: torch.dtype,
|
||||
max_fp8: float,
|
||||
min_fp8: float,
|
||||
):
|
||||
cache_k = cache_k / k_scale
|
||||
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
|
||||
cache_v = cache_v / v_scale
|
||||
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
|
||||
cache_k = cache_k.to(dtype)
|
||||
cache_v = cache_v.to(dtype)
|
||||
cache_k = cache_k.view(store_dtype)
|
||||
cache_v = cache_v.view(store_dtype)
|
||||
return cache_k, cache_v
|
||||
|
||||
|
||||
# This compiled version is slower in the unit test
|
||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def set_mla_kv_buffer_kernel(
|
||||
kv_buffer_ptr,
|
||||
|
||||
Reference in New Issue
Block a user