Split the overlapped version of TpModelWorkerClient into a separate file (#1726)

commit b48edff67f
parent 593b19f29d
Author: Lianmin Zheng
Date: 2024-10-20 00:29:29 -07:00
Committed by: GitHub
7 changed files with 217 additions and 131 deletions
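
The diff below is mostly a relocation: the scheduler now instantiates either TpModelWorker or the new TpModelWorkerClient up front, instead of constructing one class and rebinding its blocking/non-blocking methods afterwards. For readers unfamiliar with the overlap-schedule pattern, here is a minimal sketch of the idea behind such a worker client. It is not the sglang implementation: only the method names forward_batch_generation and resolve_future_token_ids come from the diff, while the class name, the thread-pool bookkeeping, and the placeholder return value are illustrative assumptions.

from concurrent.futures import ThreadPoolExecutor


class OverlapWorkerClientSketch:
    """Hypothetical stand-in for an overlap worker client: it exposes the
    same interface as the blocking worker, but runs the forward pass on a
    background thread so the scheduler can prepare the next batch while
    the GPU is busy."""

    def __init__(self, worker):
        self.worker = worker  # blocking worker with forward_batch_generation()
        self.pool = ThreadPoolExecutor(max_workers=1)  # one GPU worker thread
        self.futures = {}  # batch id -> Future of (logits_output, next_token_ids)
        self.next_bid = 0

    def forward_batch_generation(self, model_worker_batch):
        # Submit the real forward pass and return immediately; the caller
        # gets a batch id instead of real token ids.
        bid = self.next_bid
        self.next_bid += 1
        self.futures[bid] = self.pool.submit(
            self.worker.forward_batch_generation, model_worker_batch
        )
        return None, bid  # placeholder result, resolved later by batch id

    def resolve_future_token_ids(self, bid):
        # Called one scheduling step later; blocks only if the overlapped
        # forward pass has not finished yet.
        _, next_token_ids = self.futures.pop(bid).result()
        return next_token_ids.tolist()

In the diff, the scheduler wires this resolution step through its resolve_next_token_ids hook, so the hot loop never has to branch on enable_overlap_schedule.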

python/sglang/srt/managers/scheduler.py

@@ -59,6 +59,7 @@ from sglang.srt.managers.schedule_policy import (
     SchedulePolicy,
 )
 from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -146,9 +147,14 @@ class Scheduler:
         # Launch a tensor parallel worker
         if self.server_args.enable_overlap_schedule:
-            TpWorkerClass = TpModelWorker
+            TpWorkerClass = TpModelWorkerClient
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
         else:
             TpWorkerClass = TpModelWorker
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
@@ -156,16 +162,6 @@
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        if self.server_args.enable_overlap_schedule:
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-        else:
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
 
         # Get token and memory info from the model worker
         (
@@ -728,7 +724,7 @@
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.forward_batch_generation(
+                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
                     model_worker_batch
                 )
             else:
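
After this change the call site above is uniform: self.tp_worker.forward_batch_generation(...) works for both worker classes, and only the resolve_next_token_ids hook differs between modes. As a recap, these are the two hooks the constructor hunk installs (taken directly from the diff; only the comments are added here):

# Overlap mode: `x` is only a placeholder, so the hook ignores it and asks
# the worker client for the real ids recorded under this batch id.
self.resolve_next_token_ids = lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)

# Blocking mode: `x` already holds the real next-token ids as a tensor.
self.resolve_next_token_ids = lambda bid, x: x.tolist()

Selecting the class once at construction time moves the overlap/non-overlap branch out of the per-batch hot path and into __init__.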