Introduce cpu tensor as metadata to avoid blocking gpu kernel launch (#10720)

Co-authored-by: hnyls2002 <lsyincs@gmail.com>
This commit is contained in:
Zhang Junda
2025-10-02 10:51:25 +08:00
committed by GitHub
parent 609f65ba23
commit 0b2aa8a70c
8 changed files with 115 additions and 44 deletions

View File

@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
token_type_ids: torch.Tensor = None # shape: [b], int64
req_pool_indices: torch.Tensor = None # shape: [b], int64
seq_lens: torch.Tensor = None # shape: [b], int64
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
# The output locations of the KV cache
out_cache_loc: torch.Tensor = None # shape: [b], int64
output_ids: torch.Tensor = None # shape: [b], int64
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def alloc_paged_token_slots_extend(
self,
prefix_lens: torch.Tensor,
prefix_lens_cpu: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
extend_num_tokens: int,
backup_state: bool = False,
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Over estimate the number of tokens: assume each request needs a new page.
num_tokens = (
extend_num_tokens
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
)
self._evict_tree_cache_if_needed(num_tokens)
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
prefix_lens, seq_lens, last_loc, extend_num_tokens
prefix_lens,
prefix_lens_cpu,
seq_lens,
seq_lens_cpu,
last_loc,
extend_num_tokens,
)
if out_cache_loc is None:
error_msg = (
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def alloc_paged_token_slots_decode(
self,
seq_lens: torch.Tensor,
seq_lens_cpu: torch.Tensor,
last_loc: torch.Tensor,
backup_state: bool = False,
):
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if backup_state:
state = self.token_to_kv_pool_allocator.backup_state()
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
seq_lens, seq_lens_cpu, last_loc
)
if out_cache_loc is None:
error_msg = (
f"Decode out of memory. Try to lower your batch size.\n"
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
if not decoder_out_cache_loc:
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
self.device, non_blocking=True
)
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
self.device, non_blocking=True
)
prefix_lens_tensor = torch.tensor(
prefix_lens, dtype=torch.int64, device=self.device
)
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
token_type_ids_tensor = None
if len(token_type_ids) > 0:
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
prefix_lens_tensor,
)
out_cache_loc = self.alloc_paged_token_slots_extend(
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
prefix_lens_tensor,
prefix_lens_cpu_tensor,
seq_lens_tensor,
seq_lens_cpu_tensor,
last_loc,
extend_num_tokens,
)
# Set fields
self.input_ids = input_ids_tensor
self.req_pool_indices = req_pool_indices_tensor
self.seq_lens = seq_lens_tensor
self.seq_lens_cpu = seq_lens_cpu_tensor
self.orig_seq_lens = orig_seq_lens_tensor
self.out_cache_loc = out_cache_loc
self.input_embeds = (
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
)
retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while first_iter or (
not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
req = self.reqs[idx]
seq_lens_cpu = self.seq_lens.cpu().numpy()
seq_lens_cpu = self.seq_lens_cpu.numpy()
if server_args.disaggregation_mode == "decode":
req.offload_kv_cache(
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.forward_mode = ForwardMode.IDLE
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
if self.enable_overlap:
# Do not use in-place operations in the overlap mode
self.seq_lens = self.seq_lens + 1
self.seq_lens_cpu = self.seq_lens_cpu + 1
self.orig_seq_lens = self.orig_seq_lens + 1
else:
# A faster in-place version
self.seq_lens.add_(1)
self.seq_lens_cpu.add_(1)
self.orig_seq_lens.add_(1)
self.seq_lens_sum += bs
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.req_pool_indices, self.seq_lens - 2
]
self.out_cache_loc = self.alloc_paged_token_slots_decode(
self.seq_lens, last_loc
self.seq_lens, self.seq_lens_cpu, last_loc
)
self.req_to_token_pool.write(
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
self.seq_lens = self.seq_lens[keep_indices_device]
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
self.out_cache_loc = None
self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
[self.req_pool_indices, other.req_pool_indices]
)
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
self.out_cache_loc = None
self.seq_lens_sum += other.seq_lens_sum
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
self.sampling_info.grammars = None
seq_lens_cpu = (
seq_lens_cpu_cache
if seq_lens_cpu_cache is not None
else self.seq_lens.cpu()
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
)
global bid