Introduce future indices (#11301)
This commit is contained in:
@@ -1,3 +1,6 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
@@ -13,6 +16,12 @@ def _resolve_future_token_ids(input_ids, future_token_ids_map):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FutureIndices:
|
||||||
|
indices: torch.Tensor
|
||||||
|
interval: Optional[slice] = None
|
||||||
|
|
||||||
|
|
||||||
class FutureMap:
|
class FutureMap:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -30,23 +39,17 @@ class FutureMap:
|
|||||||
(self.future_buffer_len,), dtype=torch.int64, device=self.device
|
(self.future_buffer_len,), dtype=torch.int64, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_ct(self, bs: int) -> int:
|
def alloc_future_indices(self, bs: int) -> FutureIndices:
|
||||||
"""Update the circular buffer pointer and return the current pointer."""
|
"""Update the circular buffer pointer and allocate future indices."""
|
||||||
cur_future_ct = self.future_ct
|
cur_future_ct = self.future_ct
|
||||||
self.future_ct = (cur_future_ct + bs) % self.future_limit
|
self.future_ct = (cur_future_ct + bs) % self.future_limit
|
||||||
return cur_future_ct
|
start = cur_future_ct + 1
|
||||||
|
end = cur_future_ct + 1 + bs
|
||||||
|
indices = torch.arange(start, end, dtype=torch.int64, device=self.device)
|
||||||
|
return FutureIndices(indices=indices, interval=slice(start, end))
|
||||||
|
|
||||||
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
|
def resolve_future(self, model_worker_batch: ModelWorkerBatch):
|
||||||
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
|
_resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
|
||||||
|
|
||||||
def update_next_future(self, future_ct: int, bs: int):
|
def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
|
||||||
return torch.arange(
|
self.token_ids_buf[future_indices.interval] = next_token_ids
|
||||||
-(future_ct + 1),
|
|
||||||
-(future_ct + 1 + bs),
|
|
||||||
-1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
|
|
||||||
self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
|
|
||||||
|
|||||||
@@ -114,7 +114,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||||
from sglang.srt.managers.overlap_utils import FutureMap
|
from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
|
||||||
from sglang.srt.managers.schedule_batch import (
|
from sglang.srt.managers.schedule_batch import (
|
||||||
FINISH_ABORT,
|
FINISH_ABORT,
|
||||||
ModelWorkerBatch,
|
ModelWorkerBatch,
|
||||||
@@ -217,7 +217,7 @@ class GenerationBatchResult:
|
|||||||
copy_done: Optional[torch.cuda.Event] = None
|
copy_done: Optional[torch.cuda.Event] = None
|
||||||
delay_sample_launch: bool = False
|
delay_sample_launch: bool = False
|
||||||
forward_batch: Optional[ForwardBatch] = None
|
forward_batch: Optional[ForwardBatch] = None
|
||||||
future_map_ct: Optional[int] = None
|
future_indices: Optional[FutureIndices] = None
|
||||||
|
|
||||||
def copy_to_cpu(self, return_logprob: bool = False):
|
def copy_to_cpu(self, return_logprob: bool = False):
|
||||||
"""Copy tensors to CPU in overlap scheduling.
|
"""Copy tensors to CPU in overlap scheduling.
|
||||||
@@ -2092,7 +2092,7 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
bs = len(model_worker_batch.seq_lens)
|
bs = len(model_worker_batch.seq_lens)
|
||||||
cur_future_map_ct = self.future_map.update_ct(bs)
|
future_indices = self.future_map.alloc_future_indices(bs)
|
||||||
|
|
||||||
with self.forward_stream_ctx:
|
with self.forward_stream_ctx:
|
||||||
self.forward_stream.wait_stream(self.default_stream)
|
self.forward_stream.wait_stream(self.default_stream)
|
||||||
@@ -2108,22 +2108,19 @@ class Scheduler(
|
|||||||
).Event()
|
).Event()
|
||||||
if not model_worker_batch.delay_sample_launch:
|
if not model_worker_batch.delay_sample_launch:
|
||||||
self.future_map.store_to_map(
|
self.future_map.store_to_map(
|
||||||
cur_future_map_ct, bs, batch_result.next_token_ids
|
future_indices, batch_result.next_token_ids
|
||||||
)
|
)
|
||||||
batch_result.copy_to_cpu()
|
batch_result.copy_to_cpu()
|
||||||
else:
|
else:
|
||||||
batch_result.future_map_ct = cur_future_map_ct
|
batch_result.future_indices = future_indices
|
||||||
|
|
||||||
# FIXME(lsyin): move this assignment elsewhere
|
# FIXME(lsyin): move this assignment elsewhere
|
||||||
maybe_future_next_token_ids = self.future_map.update_next_future(
|
maybe_future_next_token_ids = -future_indices.indices
|
||||||
cur_future_map_ct, bs
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
batch_result = self.model_worker.forward_batch_generation(
|
batch_result = self.model_worker.forward_batch_generation(
|
||||||
batch_or_worker_batch
|
batch_or_worker_batch
|
||||||
)
|
)
|
||||||
maybe_future_next_token_ids = batch_result.next_token_ids
|
maybe_future_next_token_ids = batch_result.next_token_ids
|
||||||
copy_done = None
|
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
# TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
|
||||||
@@ -2182,8 +2179,8 @@ class Scheduler(
|
|||||||
tmp_result.logits_output,
|
tmp_result.logits_output,
|
||||||
tmp_result.forward_batch,
|
tmp_result.forward_batch,
|
||||||
)
|
)
|
||||||
ct, bs = tmp_result.future_map_ct, len(tmp_batch.reqs)
|
future_indices = tmp_result.future_indices
|
||||||
self.future_map.store_to_map(ct, bs, tmp_result.next_token_ids)
|
self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
|
||||||
tmp_result.copy_to_cpu()
|
tmp_result.copy_to_cpu()
|
||||||
self.result_queue.appendleft((tmp_batch, tmp_result))
|
self.result_queue.appendleft((tmp_batch, tmp_result))
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user