Fix memory pool index error (#616)
This commit is contained in:
@@ -137,9 +137,6 @@ class RadixAttention(nn.Module):
|
||||
|
||||
def store_kv_cache(self, cache_k, cache_v, input_metadata: InputMetadata):
|
||||
key_buffer = input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id)
|
||||
key_buffer[input_metadata.out_cache_loc] = cache_k
|
||||
value_buffer = input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id)
|
||||
if input_metadata.out_cache_loc is not None:
|
||||
key_buffer[input_metadata.out_cache_loc] = cache_k
|
||||
value_buffer[input_metadata.out_cache_loc] = cache_v
|
||||
else:
|
||||
raise RuntimeError()
|
||||
value_buffer[input_metadata.out_cache_loc] = cache_v
|
||||
|
||||
@@ -132,7 +132,8 @@ class CudaGraphRunner:
|
||||
index = bisect.bisect_left(self.batch_size_list, raw_bs)
|
||||
bs = self.batch_size_list[index]
|
||||
if bs != raw_bs:
|
||||
self.seq_lens.fill_(1)
|
||||
self.seq_lens.zero_()
|
||||
self.position_ids_offsets.fill_(1)
|
||||
self.out_cache_loc.zero_()
|
||||
|
||||
# Common inputs
|
||||
|
||||
@@ -315,7 +315,7 @@ class ModelTpServer:
|
||||
|
||||
def get_new_fill_batch(self) -> Optional[Batch]:
|
||||
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
|
||||
if running_bs > self.max_running_requests:
|
||||
if running_bs >= self.max_running_requests:
|
||||
return
|
||||
|
||||
# Compute matched prefix length
|
||||
@@ -393,7 +393,7 @@ class ModelTpServer:
|
||||
else:
|
||||
break
|
||||
|
||||
if running_bs + len(can_run_list) > self.max_running_requests:
|
||||
if running_bs + len(can_run_list) >= self.max_running_requests:
|
||||
break
|
||||
|
||||
if len(can_run_list) == 0:
|
||||
|
||||
@@ -46,7 +46,7 @@ class TokenToKVPool:
|
||||
|
||||
# [size, key/value, head_num, head_dim] for each layer
|
||||
self.kv_data = [
|
||||
torch.empty((size, 2, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
torch.empty((size + 1, 2, head_num, head_dim), dtype=dtype, device="cuda")
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user