Introduce FutureMap (#10715)
python/sglang/srt/managers/overlap_utils.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+import torch
+
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.utils import get_compiler_backend
+
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def _resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
+class FutureMap:
+    def __init__(
+        self,
+        max_running_requests: int,
+        device: torch.device,
+    ):
+        self.future_ct = 0
+        # A factor of 3 is used to avoid collision in the circular buffer.
+        self.future_limit = max_running_requests * 3
+        # A factor of 5 is used to ensure the buffer is large enough.
+        self.future_buffer_len = max_running_requests * 5
+        self.device = device
+
+        self.token_ids_buf = torch.empty(
+            (self.future_buffer_len,), dtype=torch.int64, device=self.device
+        )
+
+    def update_ct(self, bs: int) -> int:
+        """Update the circular buffer pointer and return the current pointer."""
+        cur_future_ct = self.future_ct
+        self.future_ct = (cur_future_ct + bs) % self.future_limit
+        return cur_future_ct
+
+    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
+        input_ids = model_worker_batch.input_ids
+        _resolve_future_token_ids(input_ids, self.token_ids_buf)
+
+    def update_next_future(self, future_ct: int, bs: int):
+        return torch.arange(
+            -(future_ct + 1),
+            -(future_ct + 1 + bs),
+            -1,
+            dtype=torch.int64,
+            device=self.device,
+        )
+
+    def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
+        self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids
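The core trick in _resolve_future_token_ids is that a not-yet-sampled token is encoded as a negative index into token_ids_buf: slot k is referenced by the placeholder -k, so torch.clamp(-input_ids, min=0) maps placeholders back to buffer slots while real (non-negative) token ids pass through torch.where unchanged. Below is a minimal torch-only sketch of one round trip; buf, ct, and bs are illustrative names, not part of this change.

import torch

# Stand-in for FutureMap.token_ids_buf and the pointer returned by update_ct().
buf = torch.zeros(16, dtype=torch.int64)
ct, bs = 4, 3

# Scheduler side: hand out negative placeholders -(ct + 1) .. -(ct + bs),
# the same sequence update_next_future() builds with torch.arange.
placeholders = torch.arange(-(ct + 1), -(ct + 1 + bs), -1, dtype=torch.int64)

# Worker side, one step later: the real sampled tokens are written to the
# matching slots, as store_to_map() does.
next_token_ids = torch.tensor([101, 102, 103], dtype=torch.int64)
buf[ct + 1 : ct + bs + 1] = next_token_ids

# The next batch arrives with placeholders embedded in its input_ids and is
# patched in place, mirroring _resolve_future_token_ids().
input_ids = torch.cat([torch.tensor([7, 8]), placeholders])
input_ids[:] = torch.where(
    input_ids < 0,
    buf[torch.clamp(-input_ids, min=0)],
    input_ids,
)
assert input_ids.tolist() == [7, 8, 101, 102, 103]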
@@ -36,10 +36,11 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
 )
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import DynamicGradMode, get_compiler_backend
+from sglang.srt.utils import DynamicGradMode
 from sglang.utils import get_exception_traceback
 
 if TYPE_CHECKING:
@@ -48,15 +49,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def resolve_future_token_ids(input_ids, future_token_ids_map):
-    input_ids[:] = torch.where(
-        input_ids < 0,
-        future_token_ids_map[torch.clamp(-input_ids, min=0)],
-        input_ids,
-    )
-
-
 class TpModelWorkerClient:
     """A tensor parallel model worker."""
 
@@ -79,11 +71,7 @@ class TpModelWorkerClient:
         self.gpu_id = gpu_id
 
         # Init future mappings
-        self.future_token_ids_ct = 0
-        self.future_token_ids_limit = self.max_running_requests * 3
-        self.future_token_ids_map = torch.empty(
-            (self.max_running_requests * 5,), dtype=torch.int64, device=self.device
-        )
+        self.future_map = FutureMap(self.max_running_requests, self.device)
 
         # Launch threads
         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
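The constructor above only passes max_running_requests and the device; the sizing factors live inside FutureMap (the pointer wraps at 3 * max_running_requests, the buffer holds 5 * max_running_requests slots). Since store_to_map writes at most max_running_requests slots past the pointer, the buffer cannot be overrun even right before the pointer wraps. A quick bound check with a made-up value, just to illustrate the arithmetic:

max_running_requests = 8                        # hypothetical value, not from the diff
future_limit = max_running_requests * 3         # update_ct wraps the pointer below this
future_buffer_len = max_running_requests * 5    # slots allocated in token_ids_buf
worst_ct = future_limit - 1                     # largest pointer update_ct can return
worst_bs = max_running_requests                 # largest batch stored in one step
write_end = worst_ct + worst_bs + 1             # exclusive end of buf[ct + 1 : ct + bs + 1]
assert write_end <= future_buffer_len           # 32 <= 40, so the store stays in bounds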
@@ -153,7 +141,7 @@ class TpModelWorkerClient:
         batch_lists: List = [None] * 2
 
         while True:
-            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
+            model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break
 
@@ -169,8 +157,7 @@ class TpModelWorkerClient:
             copy_done = torch.get_device_module(self.device).Event()
 
             # Resolve future tokens in the input
-            input_ids = model_worker_batch.input_ids
-            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+            self.future_map.resolve_future(model_worker_batch)
 
             # Run forward
             logits_output, next_token_ids, can_run_cuda_graph = (
@@ -187,9 +174,9 @@ class TpModelWorkerClient:
             if model_worker_batch.is_prefill_only:
                 # For prefill-only requests, create dummy token IDs on CPU
                 next_token_ids = torch.zeros(bs, dtype=torch.long)
-            self.future_token_ids_map[
-                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
-            ] = next_token_ids
+
+            # store the future indices into future map
+            self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
 
             # Copy results to the CPU
             if model_worker_batch.return_logprob:
@@ -255,20 +242,14 @@ class TpModelWorkerClient:
         sync_event.record(self.scheduler_stream)
 
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
-
-        # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
-        future_next_token_ids = torch.arange(
-            -(self.future_token_ids_ct + 1),
-            -(self.future_token_ids_ct + 1 + bs),
-            -1,
-            dtype=torch.int64,
-            device=self.device,
-        )
-        self.future_token_ids_ct = (
-            self.future_token_ids_ct + bs
-        ) % self.future_token_ids_limit
+        cur_future_map_ct = self.future_map.update_ct(bs)
+        self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
+
+        # get this forward batch's future token ids
+        future_next_token_ids = self.future_map.update_next_future(
+            cur_future_map_ct, bs
+        )
         return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
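End to end, the scheduler thread reserves slots with update_ct, enqueues the batch together with that pointer, and returns negative placeholders from update_next_future; the worker thread later writes the real sampled tokens to the same slots with store_to_map, so the following batch's resolve_future call patches them into input_ids. A single-threaded sketch of that handoff, assuming sglang is importable; it indexes token_ids_buf directly instead of calling resolve_future, which expects a full ModelWorkerBatch:

import queue

import torch

from sglang.srt.managers.overlap_utils import FutureMap

future_map = FutureMap(max_running_requests=4, device=torch.device("cpu"))
input_queue = queue.Queue()

# Scheduler side: reserve slots, enqueue the batch with the pointer, then hand
# out placeholders for tokens that have not been sampled yet.
bs = 2
ct = future_map.update_ct(bs)
input_queue.put(("dummy batch", ct, bs))               # payload shape is made up
placeholders = future_map.update_next_future(ct, bs)   # tensor([-1, -2])

# Worker side: after the forward pass, store the real tokens at the same slots.
_, ct, bs = input_queue.get()
future_map.store_to_map(ct, bs, torch.tensor([11, 22], dtype=torch.int64))

# A later batch built from the placeholders resolves to the stored tokens.
next_input_ids = placeholders.clone()
next_input_ids[:] = torch.where(
    next_input_ids < 0,
    future_map.token_ids_buf[torch.clamp(-next_input_ids, min=0)],
    next_input_ids,
)
assert next_input_ids.tolist() == [11, 22]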