[overlap-spec] support page size > 1 (#11772)
This commit is contained in:
@@ -14,6 +14,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchTokenIDOutput,
|
BatchTokenIDOutput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import BaseFinishReason, Req, ScheduleBatch
|
||||||
|
from sglang.srt.utils.common import ceil_div
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.scheduler import (
|
from sglang.srt.managers.scheduler import (
|
||||||
@@ -258,22 +259,22 @@ class SchedulerOutputProcessorMixin:
|
|||||||
|
|
||||||
if self.enable_overlap and req.finished():
|
if self.enable_overlap and req.finished():
|
||||||
indices_to_free = None
|
indices_to_free = None
|
||||||
if self.page_size == 1:
|
if batch.spec_algorithm.is_eagle():
|
||||||
if batch.spec_algorithm.is_eagle():
|
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
|
||||||
|
|
||||||
end_p = allocate_lens_list[i]
|
end_p = allocate_lens_list[i]
|
||||||
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
start_p = end_p - EagleDraftInput.ALLOC_LEN_PER_DECODE
|
||||||
indices_to_free = self.req_to_token_pool.req_to_token[
|
if self.page_size > 1:
|
||||||
req.req_pool_idx
|
start_p = ceil_div(start_p, self.page_size) * self.page_size
|
||||||
][start_p:end_p]
|
|
||||||
else:
|
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||||
|
req.req_pool_idx
|
||||||
|
][start_p:end_p]
|
||||||
|
|
||||||
|
else:
|
||||||
|
if self.page_size == 1:
|
||||||
# Free the one extra delayed token
|
# Free the one extra delayed token
|
||||||
indices_to_free = batch.out_cache_loc[i : i + 1]
|
indices_to_free = batch.out_cache_loc[i : i + 1]
|
||||||
else:
|
|
||||||
if batch.spec_algorithm.is_eagle():
|
|
||||||
# TODO(spec-v2): support eagle with page_size > 1
|
|
||||||
raise NotImplementedError()
|
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
len(req.origin_input_ids) + len(req.output_ids) - 1
|
len(req.origin_input_ids) + len(req.output_ids) - 1
|
||||||
@@ -299,6 +300,10 @@ class SchedulerOutputProcessorMixin:
|
|||||||
# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
|
# 2) overlap eagle and the current batch is prefill. This seq will not run extra iteration.
|
||||||
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
|
start_p = batch.seq_lens_cpu[i] + accept_lens_list[i]
|
||||||
end_p = allocate_lens_list[i]
|
end_p = allocate_lens_list[i]
|
||||||
|
|
||||||
|
if self.page_size > 1:
|
||||||
|
start_p = ceil_div(start_p, self.page_size) * self.page_size
|
||||||
|
|
||||||
indices_to_free = self.req_to_token_pool.req_to_token[
|
indices_to_free = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx
|
req.req_pool_idx
|
||||||
][start_p:end_p]
|
][start_p:end_p]
|
||||||
|
|||||||
@@ -10,7 +10,11 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, ScheduleBatch
|
||||||
from sglang.srt.mem_cache.common import alloc_token_slots
|
from sglang.srt.mem_cache.common import (
|
||||||
|
alloc_paged_token_slots_extend,
|
||||||
|
alloc_token_slots,
|
||||||
|
get_last_loc,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
|
||||||
from sglang.srt.model_executor.forward_batch_info import (
|
from sglang.srt.model_executor.forward_batch_info import (
|
||||||
CaptureHiddenMode,
|
CaptureHiddenMode,
|
||||||
@@ -82,9 +86,31 @@ class EagleDraftInputV2Mixin:
|
|||||||
# Now seq_lens and allocate_lens are correct
|
# Now seq_lens and allocate_lens are correct
|
||||||
batch.maybe_wait_verify_done()
|
batch.maybe_wait_verify_done()
|
||||||
|
|
||||||
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
page_size = batch.token_to_kv_pool_allocator.page_size
|
||||||
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
|
||||||
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
if page_size == 1:
|
||||||
|
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||||
|
num_needed_tokens = (new_allocate_lens - self.allocate_lens).sum().item()
|
||||||
|
out_cache_loc = alloc_token_slots(batch.tree_cache, num_needed_tokens)
|
||||||
|
else:
|
||||||
|
last_loc = get_last_loc(
|
||||||
|
batch.req_to_token_pool.req_to_token,
|
||||||
|
batch.req_pool_indices,
|
||||||
|
self.allocate_lens,
|
||||||
|
)
|
||||||
|
new_allocate_lens = batch.seq_lens + self.ALLOC_LEN_PER_DECODE
|
||||||
|
new_allocate_lens_cpu = new_allocate_lens.cpu()
|
||||||
|
allocate_lens_cpu = self.allocate_lens.cpu()
|
||||||
|
extend_num_tokens = sum(new_allocate_lens_cpu - allocate_lens_cpu).item()
|
||||||
|
out_cache_loc = alloc_paged_token_slots_extend(
|
||||||
|
batch.tree_cache,
|
||||||
|
self.allocate_lens,
|
||||||
|
allocate_lens_cpu,
|
||||||
|
new_allocate_lens,
|
||||||
|
new_allocate_lens_cpu,
|
||||||
|
last_loc,
|
||||||
|
extend_num_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
|
|||||||
Reference in New Issue
Block a user