[overlap-spec] support page size > 1 (#11772)
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchTokenIDOutput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
||||
from sglang.srt.utils.common import ceil_div
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import (
|
||||
@@ -258,22 +259,22 @@ class SchedulerOutputProcessorMixin:
|
||||
|
||||
if self.enable_overlap and req.finished():
|
||||
indices_to_free = None
|
||||
if self.page_size == 1:
|
||||
if batch.spec_algorithm.is_eagle():
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
if batch.spec_algorithm.is_eagle():
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
|
||||
end_p = allocate_lens_list[i]
|
||||
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
||||
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx
|
||||
][start_p:end_p]
|
||||
else:
|
||||
end_p = allocate_lens_list[i]
|
||||
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
||||
if self.page_size > 1:
|
||||
start_p = ceil_div(start_p, self.page_size) * self.page_size
|
||||
|
||||
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx
|
||||
][start_p:end_p]
|
||||
|
||||
else:
|
||||
if self.page_size == 1:
|
||||
# Free the one extra delayed token
|
||||
indices_to_free = batch.out_cache_loc[i : i + 1]
|
||||
else:
|
||||
if batch.spec_algorithm.is_eagle():
|
||||
# TODO(spec-v2): support eagle with page_size > 1
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
if (
|
||||
len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||
@@ -299,6 +300,10 @@ class SchedulerOutputProcessorMixin:
|
||||
# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
|
||||
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
|
||||
end_p = allocate_lens_list[i]
|
||||
|
||||
if self.page_size > 1:
|
||||
start_p = ceil_div(start_p, self.page_size) * self.page_size
|
||||
|
||||
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx
|
||||
][start_p:end_p]
|
||||
|
||||
@@ -10,7 +10,11 @@ import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
|
||||
from sglang.srt.mem_cache.common import alloc_token_slots
|
||||
from sglang.srt.mem_cache.common import (
|
||||
alloc_paged_token_slots_extend,
|
||||
alloc_token_slots,
|
||||
get_last_loc,
|
||||
)
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
CaptureHiddenMode,
|
||||
@@ -82,9 +86,31 @@ class EagleDraftInputV2Mixin:
|
||||
# Now seq_lens and allocate_lens are correct
|
||||
batch.maybe_wait_verify_done()
|
||||
|
||||
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
||||
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
||||
page_size = batch.token_to_kv_pool_allocator.page_size
|
||||
|
||||
if page_size == 1:
|
||||
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
||||
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
||||
else:
|
||||
last_loc = get_last_loc(
|
||||
batch.req_to_token_pool.req_to_token,
|
||||
batch.req_pool_indices,
|
||||
self.allocate_lens,
|
||||
)
|
||||
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||
new_allocate_lens_cpu = new_allocate_lens.cpu()
|
||||
allocate_lens_cpu = self.allocate_lens.cpu()
|
||||
extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item()
|
||||
out_cache_loc = alloc_paged_token_slots_extend(
|
||||
batch.tree_cache,
|
||||
self.allocate_lens,
|
||||
allocate_lens_cpu,
|
||||
new_allocate_lens,
|
||||
new_allocate_lens_cpu,
|
||||
last_loc,
|
||||
extend_num_tokens,
|
||||
)
|
||||
|
||||
assign_req_to_token_pool[(bs,)](
|
||||
batch.req_pool_indices,
|
||||
|
||||
Reference in New Issue
Block a user