Introduce cpu tensor as metadata to avoid blocking gpu kernel launch (#10720)
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
This commit is contained in:
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
token_type_ids: torch.Tensor = None # shape: [b], int64
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
||||
output_ids: torch.Tensor = None # shape: [b], int64
|
||||
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def alloc_paged_token_slots_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
backup_state: bool = False,
|
||||
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Over estimate the number of tokens: assume each request needs a new page.
|
||||
num_tokens = (
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||
prefix_lens,
|
||||
prefix_lens_cpu,
|
||||
seq_lens,
|
||||
seq_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def alloc_paged_token_slots_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
||||
seq_lens, seq_lens_cpu, last_loc
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
f"Decode out of memory. Try to lower your batch size.\n"
|
||||
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
prefix_lens_tensor = torch.tensor(
|
||||
prefix_lens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
|
||||
|
||||
token_type_ids_tensor = None
|
||||
if len(token_type_ids) > 0:
|
||||
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens_tensor,
|
||||
)
|
||||
out_cache_loc = self.alloc_paged_token_slots_extend(
|
||||
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
|
||||
prefix_lens_tensor,
|
||||
prefix_lens_cpu_tensor,
|
||||
seq_lens_tensor,
|
||||
seq_lens_cpu_tensor,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu_tensor
|
||||
self.orig_seq_lens = orig_seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while first_iter or (
|
||||
not self.check_decode_mem(selected_indices=sorted_indices)
|
||||
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||
req = self.reqs[idx]
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
||||
|
||||
if server_args.disaggregation_mode == "decode":
|
||||
req.offload_kv_cache(
|
||||
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.forward_mode = ForwardMode.IDLE
|
||||
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
|
||||
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if self.enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
self.seq_lens_cpu = self.seq_lens_cpu + 1
|
||||
self.orig_seq_lens = self.orig_seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_cpu.add_(1)
|
||||
self.orig_seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_pool_indices, self.seq_lens - 2
|
||||
]
|
||||
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
||||
self.seq_lens, last_loc
|
||||
self.seq_lens, self.seq_lens_cpu, last_loc
|
||||
)
|
||||
|
||||
self.req_to_token_pool.write(
|
||||
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
|
||||
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
|
||||
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum += other.seq_lens_sum
|
||||
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.sampling_info.grammars = None
|
||||
|
||||
seq_lens_cpu = (
|
||||
seq_lens_cpu_cache
|
||||
if seq_lens_cpu_cache is not None
|
||||
else self.seq_lens.cpu()
|
||||
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
|
||||
)
|
||||
|
||||
global bid
|
||||
|
||||
Reference in New Issue
Block a user