Split the overlapped version of TpModelWorkerClient into a separate file (#1726)
This commit is contained in:
@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
SchedulePolicy,
|
||||
)
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -146,9 +147,14 @@ class Scheduler:
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
TpWorkerClass = TpModelWorker
|
||||
TpWorkerClass = TpModelWorkerClient
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
else:
|
||||
TpWorkerClass = TpModelWorker
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
|
||||
self.tp_worker = TpWorkerClass(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
@@ -156,16 +162,6 @@ class Scheduler:
|
||||
dp_rank=dp_rank,
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
self.forward_batch_generation = (
|
||||
self.tp_worker.forward_batch_generation_non_blocking
|
||||
)
|
||||
else:
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
@@ -728,7 +724,7 @@ class Scheduler:
|
||||
if self.is_generation:
|
||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = self.forward_batch_generation(
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user