Intoduce cpu tensor as metadata to avoid blocking gpu kernel launch (#10720)
Co-authored-by: hnyls2002 <lsyincs@gmail.com>
This commit is contained in:
@@ -76,6 +76,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
req_pool_indices, dtype=torch.int64, device=self.device
|
||||
)
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
self.orig_seq_lens = torch.tensor(
|
||||
seq_lens, dtype=torch.int32, device=self.device
|
||||
)
|
||||
|
||||
@@ -900,6 +900,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
token_type_ids: torch.Tensor = None # shape: [b], int64
|
||||
req_pool_indices: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens: torch.Tensor = None # shape: [b], int64
|
||||
seq_lens_cpu: torch.Tensor = None # shape: [b], int64
|
||||
# The output locations of the KV cache
|
||||
out_cache_loc: torch.Tensor = None # shape: [b], int64
|
||||
output_ids: torch.Tensor = None # shape: [b], int64
|
||||
@@ -1055,7 +1056,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def alloc_paged_token_slots_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
backup_state: bool = False,
|
||||
@@ -1063,7 +1066,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
# Over estimate the number of tokens: assume each request needs a new page.
|
||||
num_tokens = (
|
||||
extend_num_tokens
|
||||
+ len(seq_lens) * self.token_to_kv_pool_allocator.page_size
|
||||
+ len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
|
||||
)
|
||||
self._evict_tree_cache_if_needed(num_tokens)
|
||||
|
||||
@@ -1071,7 +1074,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
||||
prefix_lens, seq_lens, last_loc, extend_num_tokens
|
||||
prefix_lens,
|
||||
prefix_lens_cpu,
|
||||
seq_lens,
|
||||
seq_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
@@ -1090,6 +1098,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
def alloc_paged_token_slots_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
backup_state: bool = False,
|
||||
):
|
||||
@@ -1100,7 +1109,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if backup_state:
|
||||
state = self.token_to_kv_pool_allocator.backup_state()
|
||||
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(seq_lens, last_loc)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
|
||||
seq_lens, seq_lens_cpu, last_loc
|
||||
)
|
||||
if out_cache_loc is None:
|
||||
error_msg = (
|
||||
f"Decode out of memory. Try to lower your batch size.\n"
|
||||
@@ -1169,6 +1180,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
|
||||
if not decoder_out_cache_loc:
|
||||
self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
|
||||
@@ -1217,12 +1229,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
seq_lens_cpu_tensor = torch.tensor(seq_lens, dtype=torch.int64)
|
||||
orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
|
||||
self.device, non_blocking=True
|
||||
)
|
||||
prefix_lens_tensor = torch.tensor(
|
||||
prefix_lens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
|
||||
|
||||
token_type_ids_tensor = None
|
||||
if len(token_type_ids) > 0:
|
||||
@@ -1349,13 +1363,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
prefix_lens_tensor,
|
||||
)
|
||||
out_cache_loc = self.alloc_paged_token_slots_extend(
|
||||
prefix_lens_tensor, seq_lens_tensor, last_loc, extend_num_tokens
|
||||
prefix_lens_tensor,
|
||||
prefix_lens_cpu_tensor,
|
||||
seq_lens_tensor,
|
||||
seq_lens_cpu_tensor,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
|
||||
# Set fields
|
||||
self.input_ids = input_ids_tensor
|
||||
self.req_pool_indices = req_pool_indices_tensor
|
||||
self.seq_lens = seq_lens_tensor
|
||||
self.seq_lens_cpu = seq_lens_cpu_tensor
|
||||
self.orig_seq_lens = orig_seq_lens_tensor
|
||||
self.out_cache_loc = out_cache_loc
|
||||
self.input_embeds = (
|
||||
@@ -1498,7 +1518,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while first_iter or (
|
||||
not self.check_decode_mem(selected_indices=sorted_indices)
|
||||
@@ -1548,7 +1567,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
|
||||
req = self.reqs[idx]
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
seq_lens_cpu = self.seq_lens_cpu.numpy()
|
||||
|
||||
if server_args.disaggregation_mode == "decode":
|
||||
req.offload_kv_cache(
|
||||
@@ -1592,6 +1611,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.forward_mode = ForwardMode.IDLE
|
||||
self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
|
||||
self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
|
||||
self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
|
||||
@@ -1651,10 +1671,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
if self.enable_overlap:
|
||||
# Do not use in-place operations in the overlap mode
|
||||
self.seq_lens = self.seq_lens + 1
|
||||
self.seq_lens_cpu = self.seq_lens_cpu + 1
|
||||
self.orig_seq_lens = self.orig_seq_lens + 1
|
||||
else:
|
||||
# A faster in-place version
|
||||
self.seq_lens.add_(1)
|
||||
self.seq_lens_cpu.add_(1)
|
||||
self.orig_seq_lens.add_(1)
|
||||
self.seq_lens_sum += bs
|
||||
|
||||
@@ -1673,7 +1695,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.req_pool_indices, self.seq_lens - 2
|
||||
]
|
||||
self.out_cache_loc = self.alloc_paged_token_slots_decode(
|
||||
self.seq_lens, last_loc
|
||||
self.seq_lens, self.seq_lens_cpu, last_loc
|
||||
)
|
||||
|
||||
self.req_to_token_pool.write(
|
||||
@@ -1719,6 +1741,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
|
||||
self.req_pool_indices = self.req_pool_indices[keep_indices_device]
|
||||
self.seq_lens = self.seq_lens[keep_indices_device]
|
||||
self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
|
||||
self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum = self.seq_lens.sum().item()
|
||||
@@ -1759,6 +1782,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
[self.req_pool_indices, other.req_pool_indices]
|
||||
)
|
||||
self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
|
||||
self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
|
||||
self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
|
||||
self.out_cache_loc = None
|
||||
self.seq_lens_sum += other.seq_lens_sum
|
||||
@@ -1802,9 +1826,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.sampling_info.grammars = None
|
||||
|
||||
seq_lens_cpu = (
|
||||
seq_lens_cpu_cache
|
||||
if seq_lens_cpu_cache is not None
|
||||
else self.seq_lens.cpu()
|
||||
seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
|
||||
)
|
||||
|
||||
global bid
|
||||
|
||||
@@ -27,7 +27,7 @@ import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.utils import get_bool_env_var, next_power_of_2
|
||||
from sglang.srt.utils import get_bool_env_var, get_num_new_pages, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache
|
||||
@@ -294,7 +294,6 @@ def alloc_extend_kernel(
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
max_num_extend_tokens: tl.constexpr,
|
||||
@@ -323,13 +322,6 @@ def alloc_extend_kernel(
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
merged_value = (sum_num_new_pages.to(tl.int64)) << 32 | sum_extend_lens.to(
|
||||
tl.int64
|
||||
)
|
||||
tl.store(ret_values, merged_value)
|
||||
|
||||
# Part 1: fill the old partial page
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
num_part1 = (
|
||||
@@ -381,7 +373,6 @@ def alloc_decode_kernel(
|
||||
last_loc_ptr,
|
||||
free_page_ptr,
|
||||
out_indices,
|
||||
ret_values,
|
||||
bs_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
):
|
||||
@@ -404,10 +395,6 @@ def alloc_decode_kernel(
|
||||
sum_num_new_pages = tl.sum(num_new_pages)
|
||||
new_page_start_loc = sum_num_new_pages - num_page_start_loc_self
|
||||
|
||||
# Return value
|
||||
if pid == tl.num_programs(0) - 1:
|
||||
tl.store(ret_values, sum_num_new_pages)
|
||||
|
||||
if num_page_start_loc_self == 0:
|
||||
last_loc = tl.load(last_loc_ptr + pid)
|
||||
tl.store(out_indices + pid, last_loc + 1)
|
||||
@@ -438,7 +425,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
||||
self.num_pages = size // page_size
|
||||
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
||||
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
||||
self.seen_max_num_extend_tokens_next_power_of_2 = 1
|
||||
self.clear()
|
||||
|
||||
@@ -468,7 +454,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
def alloc_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
@@ -497,7 +485,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
self.seen_max_num_extend_tokens_next_power_of_2,
|
||||
@@ -506,8 +493,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
merged_value = self.ret_values.item()
|
||||
num_new_pages = merged_value >> 32
|
||||
num_new_pages = get_num_new_pages(prefix_lens_cpu, seq_lens_cpu, self.page_size)
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
@@ -517,6 +503,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
def alloc_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if self.debug_mode:
|
||||
@@ -534,7 +521,6 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
last_loc,
|
||||
self.free_pages,
|
||||
out_indices,
|
||||
self.ret_values,
|
||||
next_power_of_2(bs),
|
||||
self.page_size,
|
||||
)
|
||||
@@ -542,7 +528,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
||||
if self.debug_mode:
|
||||
assert len(torch.unique(out_indices)) == len(out_indices)
|
||||
|
||||
num_new_pages = self.ret_values.item()
|
||||
num_new_pages = get_num_new_pages(
|
||||
seq_lens_cpu - 1, seq_lens_cpu, self.page_size, decode=True
|
||||
)
|
||||
if num_new_pages > len(self.free_pages):
|
||||
return None
|
||||
|
||||
|
||||
@@ -69,7 +69,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
def alloc_extend(
|
||||
self,
|
||||
prefix_lens: torch.Tensor,
|
||||
prefix_lens_cpu: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
extend_num_tokens: int,
|
||||
):
|
||||
@@ -80,8 +82,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
|
||||
num_new_pages = (
|
||||
(
|
||||
(seq_lens + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens + self.page_size - 1) // self.page_size
|
||||
(seq_lens_cpu + self.page_size - 1) // self.page_size
|
||||
- (prefix_lens_cpu + self.page_size - 1) // self.page_size
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
@@ -115,6 +117,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
def alloc_decode(
|
||||
self,
|
||||
seq_lens: torch.Tensor,
|
||||
seq_lens_cpu: torch.Tensor,
|
||||
last_loc: torch.Tensor,
|
||||
):
|
||||
if self.debug_mode:
|
||||
@@ -123,7 +126,8 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
||||
)
|
||||
|
||||
need_new_pages = (seq_lens % self.page_size == 1).int()
|
||||
num_new_pages = need_new_pages.sum().item()
|
||||
need_new_pages_cpu = (seq_lens_cpu % self.page_size == 1).int()
|
||||
num_new_pages = need_new_pages_cpu.sum().item()
|
||||
|
||||
if num_new_pages > len(self.free_pages):
|
||||
self.merge_and_sort_free()
|
||||
|
||||
@@ -104,14 +104,21 @@ class EagleVerifyInput(SpecInput):
|
||||
end_offset = batch.seq_lens + self.draft_token_num
|
||||
else:
|
||||
prefix_lens = batch.seq_lens
|
||||
prefix_lens_cpu = batch.seq_lens_cpu
|
||||
end_offset = prefix_lens + self.draft_token_num
|
||||
end_offset_cpu = prefix_lens_cpu + self.draft_token_num
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
prefix_lens,
|
||||
)
|
||||
batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens, end_offset, last_loc, len(batch.input_ids)
|
||||
prefix_lens,
|
||||
prefix_lens_cpu,
|
||||
end_offset,
|
||||
end_offset_cpu,
|
||||
last_loc,
|
||||
len(batch.input_ids),
|
||||
)
|
||||
self.last_loc = last_loc
|
||||
|
||||
@@ -380,6 +387,8 @@ class EagleVerifyInput(SpecInput):
|
||||
verified_id = predict[accept_index]
|
||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||
evict_mask[accept_index] = False
|
||||
accept_length_cpu = accept_length.cpu()
|
||||
accept_length_list = accept_length_cpu.tolist()
|
||||
|
||||
if page_size == 1:
|
||||
# TODO: boolean array index leads to a device sync. Remove it.
|
||||
@@ -456,13 +465,15 @@ class EagleVerifyInput(SpecInput):
|
||||
else:
|
||||
batch.out_cache_loc = tgt_cache_loc
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
|
||||
|
||||
draft_input = EagleDraftInput(
|
||||
hidden_states=batch.spec_info.hidden_states[accept_index],
|
||||
verified_id=verified_id,
|
||||
accept_length=accept_length,
|
||||
accept_length_cpu=accept_length.tolist(),
|
||||
accept_length_cpu=accept_length_list,
|
||||
seq_lens_for_draft_extend=batch.seq_lens,
|
||||
seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu,
|
||||
req_pool_indices_for_draft_extend=batch.req_pool_indices,
|
||||
)
|
||||
|
||||
@@ -485,15 +496,15 @@ class EagleVerifyInput(SpecInput):
|
||||
next_power_of_2(bs),
|
||||
)
|
||||
batch.seq_lens.add_(accept_length + 1)
|
||||
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
|
||||
|
||||
accept_length_cpu = accept_length.tolist()
|
||||
if len(unfinished_accept_index) > 0:
|
||||
unfinished_accept_index = torch.cat(unfinished_accept_index)
|
||||
unfinished_index_device = torch.tensor(
|
||||
unfinished_index, dtype=torch.int64, device=predict.device
|
||||
)
|
||||
draft_input_accept_length_cpu = [
|
||||
accept_length_cpu[i] for i in unfinished_index
|
||||
accept_length_list[i] for i in unfinished_index
|
||||
]
|
||||
if page_size == 1 or self.topk == 1:
|
||||
batch.out_cache_loc = batch.out_cache_loc[unfinished_accept_index]
|
||||
@@ -508,6 +519,7 @@ class EagleVerifyInput(SpecInput):
|
||||
unfinished_index_device,
|
||||
batch.seq_lens,
|
||||
)
|
||||
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
|
||||
filter_finished_cache_loc_kernel[(bs,)](
|
||||
batch.out_cache_loc,
|
||||
tgt_cache_loc,
|
||||
@@ -525,6 +537,7 @@ class EagleVerifyInput(SpecInput):
|
||||
accept_length_cpu=draft_input_accept_length_cpu,
|
||||
accept_length=accept_length[unfinished_index_device],
|
||||
seq_lens_for_draft_extend=batch.seq_lens[unfinished_index_device],
|
||||
seq_lens_for_draft_extend_cpu=batch.seq_lens_cpu[unfinished_index],
|
||||
req_pool_indices_for_draft_extend=batch.req_pool_indices[
|
||||
unfinished_index_device
|
||||
],
|
||||
@@ -542,7 +555,7 @@ class EagleVerifyInput(SpecInput):
|
||||
draft_input=draft_input,
|
||||
logits_output=logits_output,
|
||||
verified_id=verified_id,
|
||||
accept_length_per_req_cpu=accept_length_cpu,
|
||||
accept_length_per_req_cpu=accept_length_list,
|
||||
accepted_indices=accept_index,
|
||||
)
|
||||
|
||||
@@ -575,6 +588,7 @@ class EagleDraftInput(SpecInput):
|
||||
# Inputs for draft extend
|
||||
# shape: (b,)
|
||||
seq_lens_for_draft_extend: torch.Tensor = None
|
||||
seq_lens_for_draft_extend_cpu: torch.Tensor = None
|
||||
req_pool_indices_for_draft_extend: torch.Tensor = None
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -631,6 +645,7 @@ class EagleDraftInput(SpecInput):
|
||||
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.seq_lens_cpu = batch.spec_info.seq_lens_for_draft_extend_cpu
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
batch.return_logprob = False
|
||||
batch.return_hidden_states = False
|
||||
|
||||
@@ -543,6 +543,8 @@ class EAGLEWorker(TpModelWorker):
|
||||
batch.seq_lens,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
prefix_lens_cpu = batch.seq_lens_cpu
|
||||
seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
|
||||
extend_num_tokens = num_seqs * self.speculative_num_steps
|
||||
else:
|
||||
# In this case, the last partial page needs to be duplicated.
|
||||
@@ -578,14 +580,23 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.page_size,
|
||||
)
|
||||
|
||||
# TODO(lmzheng): remove this device sync
|
||||
extend_num_tokens = torch.sum(self.extend_lens).item()
|
||||
prefix_lens_cpu = batch.seq_lens_cpu
|
||||
last_page_lens = prefix_lens_cpu % self.page_size
|
||||
num_new_pages_per_topk = (
|
||||
last_page_lens + self.speculative_num_steps + self.page_size - 1
|
||||
) // self.page_size
|
||||
seq_lens_cpu = (
|
||||
prefix_lens_cpu // self.page_size * self.page_size
|
||||
+ num_new_pages_per_topk * (self.page_size * self.topk)
|
||||
)
|
||||
extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()
|
||||
|
||||
out_cache_loc, token_to_kv_pool_state_backup = (
|
||||
batch.alloc_paged_token_slots_extend(
|
||||
prefix_lens,
|
||||
prefix_lens_cpu,
|
||||
seq_lens,
|
||||
seq_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
backup_state=True,
|
||||
@@ -1003,6 +1014,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
assert isinstance(batch.spec_info, EagleDraftInput)
|
||||
# Backup fields that will be modified in-place
|
||||
seq_lens_backup = batch.seq_lens.clone()
|
||||
seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
|
||||
req_pool_indices_backup = batch.req_pool_indices
|
||||
accept_length_backup = batch.spec_info.accept_length
|
||||
return_logprob_backup = batch.return_logprob
|
||||
@@ -1081,6 +1093,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
|
||||
)
|
||||
batch.seq_lens = seq_lens_backup
|
||||
batch.seq_lens_cpu = seq_lens_cpu_backup
|
||||
batch.req_pool_indices = req_pool_indices_backup
|
||||
batch.spec_info.accept_length = accept_length_backup
|
||||
batch.return_logprob = return_logprob_backup
|
||||
|
||||
@@ -77,6 +77,7 @@ class NgramVerifyInput(SpecInput):
|
||||
batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
|
||||
end_offset = batch.seq_lens + self.draft_token_num
|
||||
else:
|
||||
# TODO(lsyin): add prefix lens cpu here to support page size > 1
|
||||
prefix_lens = batch.seq_lens
|
||||
end_offset = prefix_lens + self.draft_token_num
|
||||
last_loc = get_last_loc(
|
||||
@@ -405,10 +406,13 @@ class NgramVerifyInput(SpecInput):
|
||||
self._fill_requests(batch, logits_output)
|
||||
self._free_cache(batch, page_size)
|
||||
|
||||
batch.seq_lens.add_(self.accept_length + 1)
|
||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||
accept_length_cpu = self.accept_length.cpu()
|
||||
num_accepted_tokens = accept_length_cpu.sum().item()
|
||||
|
||||
return logits_output, self.verified_id, self.accept_length.sum().item()
|
||||
batch.seq_lens.add_(self.accept_length + 1)
|
||||
batch.seq_lens_cpu.add_(accept_length_cpu + 1)
|
||||
|
||||
return logits_output, self.verified_id, num_accepted_tokens
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
|
||||
pass
|
||||
|
||||
@@ -3250,6 +3250,30 @@ def get_extend_input_len_swa_limit(
|
||||
return page_size + 2 * max(sliding_window_size, chunked_prefill_size)
|
||||
|
||||
|
||||
def get_num_new_pages(
|
||||
prefix_lens: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
page_size: int,
|
||||
decode: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Get the number of new pages for the given prefix and sequence lengths. We use cpu tensors to avoid blocking kernel launch.
|
||||
"""
|
||||
cpu_device = torch.device("cpu")
|
||||
assert prefix_lens.device == cpu_device
|
||||
assert seq_lens.device == cpu_device
|
||||
num_pages_after = (seq_lens + page_size - 1) // page_size
|
||||
num_pages_before = (prefix_lens + page_size - 1) // page_size
|
||||
num_new_pages = num_pages_after - num_pages_before
|
||||
extend_lens = seq_lens - prefix_lens
|
||||
sum_num_new_pages = torch.sum(num_new_pages).to(torch.int64)
|
||||
if decode:
|
||||
return sum_num_new_pages.item()
|
||||
merged_value = (sum_num_new_pages) << 32 | torch.sum(extend_lens).to(torch.int64)
|
||||
|
||||
return merged_value.item() >> 32
|
||||
|
||||
|
||||
class CachedKernel:
|
||||
"""
|
||||
Wrapper that allows kernel[grid](...) syntax with caching based on a key function.
|
||||
|
||||
Reference in New Issue
Block a user