Lianmin/simplify memory pool (#7202)
This commit is contained in:
@@ -34,7 +34,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.utils import debug_timing, get_compiler_backend, is_cuda
|
from sglang.srt.utils import debug_timing, is_cuda, next_power_of_2
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -182,6 +182,9 @@ class TokenToKVPoolAllocator:
|
|||||||
def available_size(self):
|
def available_size(self):
|
||||||
return len(self.free_slots)
|
return len(self.free_slots)
|
||||||
|
|
||||||
|
def debug_print(self) -> str:
|
||||||
|
return ""
|
||||||
|
|
||||||
def get_kvcache(self):
|
def get_kvcache(self):
|
||||||
return self._kvcache
|
return self._kvcache
|
||||||
|
|
||||||
@@ -314,17 +317,25 @@ class MHATokenToKVPool(KVCache):
|
|||||||
# layer_num x [seq_len, head_num, head_dim]
|
# layer_num x [seq_len, head_num, head_dim]
|
||||||
# layer_num x [page_num, page_size, head_num, head_dim]
|
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||||
kv_data_ptrs = [
|
kv_data_ptrs = [
|
||||||
self.get_key_buffer(i).data_ptr() for i in range(self.layer_num)
|
self.get_key_buffer(i).data_ptr()
|
||||||
] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)]
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
|
] + [
|
||||||
|
self.get_value_buffer(i).data_ptr()
|
||||||
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
|
]
|
||||||
kv_data_lens = [
|
kv_data_lens = [
|
||||||
self.get_key_buffer(i).nbytes for i in range(self.layer_num)
|
self.get_key_buffer(i).nbytes
|
||||||
] + [self.get_value_buffer(i).nbytes for i in range(self.layer_num)]
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
|
] + [
|
||||||
|
self.get_value_buffer(i).nbytes
|
||||||
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
|
]
|
||||||
kv_item_lens = [
|
kv_item_lens = [
|
||||||
self.get_key_buffer(i)[0].nbytes * self.page_size
|
self.get_key_buffer(i)[0].nbytes * self.page_size
|
||||||
for i in range(self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
] + [
|
] + [
|
||||||
self.get_value_buffer(i)[0].nbytes * self.page_size
|
self.get_value_buffer(i)[0].nbytes * self.page_size
|
||||||
for i in range(self.layer_num)
|
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||||
]
|
]
|
||||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||||
|
|
||||||
@@ -444,36 +455,6 @@ class MHATokenToKVPool(KVCache):
|
|||||||
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
self.v_buffer[layer_id - self.start_layer][loc] = cache_v
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
|
||||||
def fused_downcast(
|
|
||||||
cache_k: torch.Tensor,
|
|
||||||
cache_v: torch.Tensor,
|
|
||||||
k_scale: torch.Tensor,
|
|
||||||
v_scale: torch.Tensor,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
store_dtype: torch.dtype,
|
|
||||||
max_fp8: float,
|
|
||||||
min_fp8: float,
|
|
||||||
):
|
|
||||||
cache_k = cache_k / k_scale
|
|
||||||
cache_k = torch.clamp(cache_k, min_fp8, max_fp8)
|
|
||||||
cache_v = cache_v / v_scale
|
|
||||||
cache_v = torch.clamp(cache_v, min_fp8, max_fp8)
|
|
||||||
cache_k = cache_k.to(dtype)
|
|
||||||
cache_v = cache_v.to(dtype)
|
|
||||||
cache_k = cache_k.view(store_dtype)
|
|
||||||
cache_v = cache_v.view(store_dtype)
|
|
||||||
return cache_k, cache_v
|
|
||||||
|
|
||||||
|
|
||||||
# This compiled version is slower in the unit test
|
|
||||||
# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
|
|
||||||
@torch.compile(dynamic=True, backend=get_compiler_backend())
|
|
||||||
def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
|
||||||
dst_1[loc] = src_1.to(dtype).view(store_dtype)
|
|
||||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def set_mla_kv_buffer_kernel(
|
def set_mla_kv_buffer_kernel(
|
||||||
kv_buffer_ptr,
|
kv_buffer_ptr,
|
||||||
|
|||||||
Reference in New Issue
Block a user