Introduce future indices (#11301)
This commit is contained in:
@@ -114,7 +114,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||
from sglang.srt.managers.overlap_utils import FutureMap
|
||||
from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ModelWorkerBatch,
|
||||
@@ -217,7 +217,7 @@ class GenerationBatchResult:
|
||||
copy_done: Optional[torch.cuda.Event] = None
|
||||
delay_sample_launch: bool = False
|
||||
forward_batch: Optional[ForwardBatch] = None
|
||||
future_map_ct: Optional[int] = None
|
||||
future_indices: Optional[FutureIndices] = None
|
||||
|
||||
def copy_to_cpu(self, return_logprob: bool = False):
|
||||
"""Copy tensors to CPU in overlap scheduling.
|
||||
@@ -2092,7 +2092,7 @@ class Scheduler(
|
||||
)
|
||||
|
||||
bs = len(model_worker_batch.seq_lens)
|
||||
cur_future_map_ct = self.future_map.update_ct(bs)
|
||||
future_indices = self.future_map.alloc_future_indices(bs)
|
||||
|
||||
with self.forward_stream_ctx:
|
||||
self.forward_stream.wait_stream(self.default_stream)
|
||||
@@ -2108,22 +2108,19 @@ class Scheduler(
|
||||
).Event()
|
||||
if not model_worker_batch.delay_sample_launch:
|
||||
self.future_map.store_to_map(
|
||||
cur_future_map_ct, bs, batch_result.next_token_ids
|
||||
future_indices, batch_result.next_token_ids
|
||||
)
|
||||
batch_result.copy_to_cpu()
|
||||
else:
|
||||
batch_result.future_map_ct = cur_future_map_ct
|
||||
batch_result.future_indices = future_indices
|
||||
|
||||
# FIXME(lsyin): move this assignment elsewhere
|
||||
maybe_future_next_token_ids = self.future_map.update_next_future(
|
||||
cur_future_map_ct, bs
|
||||
)
|
||||
maybe_future_next_token_ids = -future_indices.indices
|
||||
else:
|
||||
batch_result = self.model_worker.forward_batch_generation(
|
||||
batch_or_worker_batch
|
||||
)
|
||||
maybe_future_next_token_ids = batch_result.next_token_ids
|
||||
copy_done = None
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||
@@ -2182,8 +2179,8 @@ class Scheduler(
|
||||
tmp_result.logits_output,
|
||||
tmp_result.forward_batch,
|
||||
)
|
||||
ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
|
||||
self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
|
||||
future_indices = tmp_result.future_indices
|
||||
self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
|
||||
tmp_result.copy_to_cpu()
|
||||
self.result_queue.appendleft((tmp_batch, tmp_result))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user