Unify the memory pool api and tp worker API (#1724)

Lianmin Zheng
2024-10-19 23:19:26 -07:00
committed by GitHub
parent 95946271af
commit 59cbf47626
8 changed files with 87 additions and 25 deletions
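The scheduler-side hunks below replace the worker's `get_token_and_memory_info()` with a single `get_worker_info()` accessor that also hands back the worker's copy of `global_server_args_dict` (merged into the scheduler's own), and they route construction of the TP worker through one `TpWorkerClass` variable, which appears to prepare a seam for swapping in a dedicated overlap-scheduling worker class (both branches still pick `TpModelWorker` in this commit). Below is a minimal, self-contained sketch of that accessor shape; the `DemoTpWorker` class, the placeholder values, and the elided extra return fields are illustrative assumptions, not the actual sglang definitions.

```python
# Hypothetical sketch of the unified worker-info accessor introduced by this
# commit. Only the field names visible in the hunks are used; everything else
# (class name, values, the extra fields the real worker returns) is a placeholder.

global_server_args_dict = {}  # scheduler-side copy, mirroring schedule_batch.global_server_args_dict


class DemoTpWorker:
    """Stand-in for TpModelWorker, exposing only the unified accessor."""

    def __init__(self):
        self.max_req_input_len = 4096                    # placeholder value
        self.random_seed = 42                            # placeholder value
        self.device = "cuda"                             # placeholder value
        self.server_args_view = {"example_flag": True}   # placeholder for the worker's global_server_args_dict

    def get_worker_info(self):
        # The real worker returns more fields (token/memory limits, memory-pool
        # handles); they are elided here because the hunk does not name them.
        return (
            self.max_req_input_len,
            self.random_seed,
            self.device,
            self.server_args_view,
        )


worker = DemoTpWorker()
max_req_input_len, random_seed, device, worker_args = worker.get_worker_info()
global_server_args_dict.update(worker_args)  # scheduler adopts the worker's view of the server args
print(max_req_input_len, device, global_server_args_dict)
```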


@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
ImageInputs,
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.managers.schedule_policy import (
AddReqResult,
@@ -144,25 +145,27 @@ class Scheduler:
)
# Launch a tensor parallel worker
self.tp_worker = TpModelWorker(
if self.server_args.enable_overlap_schedule:
TpWorkerClass = TpModelWorker
else:
TpWorkerClass = TpModelWorker
self.tp_worker = TpWorkerClass(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
)
# Init states for overlap schedule
if self.server_args.enable_overlap_schedule:
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
self.resolve_next_token_ids = (
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
)
self.forward_batch_generation = (
self.tp_worker.forward_batch_generation_non_blocking
)
else:
self.forward_batch_generation = self.tp_worker.forward_batch_generation
self.resolve_next_token_ids = lambda bid, x: x.tolist()
self.forward_batch_generation = self.tp_worker.forward_batch_generation
# Get token and memory info from the model worker
(
@@ -172,9 +175,14 @@ class Scheduler:
self.max_req_input_len,
self.random_seed,
self.device,
) = self.tp_worker.get_token_and_memory_info()
worker_global_server_args_dict,
_,
_,
_,
) = self.tp_worker.get_worker_info()
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
# Print debug info
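For context on the hunk above: with overlap scheduling the forward call is non-blocking and the real token ids are fetched later via `resolve_future_token_ids(bid)`; without it the forward call blocks and resolving is just a tensor-to-list conversion. The self-contained sketch below shows how the two callables could be consumed. The method names (`forward_batch_generation`, `forward_batch_generation_non_blocking`, `resolve_future_token_ids`) come from the diff; the `DemoWorker` class, the simplified signatures, and the call site at the end are assumptions for illustration only.

```python
# Hedged sketch of the two dispatch paths the scheduler wires up in __init__.
# Bodies and signatures are invented; only the method names mirror the diff.
import torch


class DemoWorker:
    def __init__(self):
        self._futures = {}
        self._next_bid = 0

    def forward_batch_generation(self, batch):
        # Blocking path: run the model and return concrete next-token ids.
        return self._next_bid, torch.tensor([1, 2, 3])

    def forward_batch_generation_non_blocking(self, batch):
        # Overlap path: enqueue the work and hand back only a batch id.
        bid = self._next_bid
        self._next_bid += 1
        self._futures[bid] = [1, 2, 3]  # pretend the GPU fills this in later
        return bid, None

    def resolve_future_token_ids(self, bid):
        return self._futures.pop(bid)


worker = DemoWorker()
enable_overlap_schedule = True

if enable_overlap_schedule:
    forward_batch_generation = worker.forward_batch_generation_non_blocking
    resolve_next_token_ids = lambda bid, x: worker.resolve_future_token_ids(bid)
else:
    forward_batch_generation = worker.forward_batch_generation
    resolve_next_token_ids = lambda bid, x: x.tolist()

# Downstream code can now be written once for both modes:
bid, maybe_ids = forward_batch_generation(batch=None)
next_token_ids = resolve_next_token_ids(bid, maybe_ids)
print(next_token_ids)
```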
@@ -266,6 +274,7 @@ class Scheduler:
@torch.inference_mode()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
while True:
@@ -296,6 +305,7 @@ class Scheduler:
@torch.inference_mode()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
self.last_batch = None
@@ -572,6 +582,7 @@ class Scheduler:
else set([])
)
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if (
self.lora_paths
@@ -673,6 +684,7 @@ class Scheduler:
return new_batch
def update_running_batch(self):
"""Update the current running decoding batch."""
global test_retract
batch = self.running_batch
@@ -712,6 +724,7 @@ class Scheduler:
batch.prepare_for_decode()
def run_batch(self, batch: ScheduleBatch):
"""Run a batch."""
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
model_worker_batch = batch.get_model_worker_batch()
@@ -933,6 +946,7 @@ class Scheduler:
return num_input_logprobs
def stream_output(self, reqs: List[Req]):
"""Stream the output to detokenizer."""
output_rids = []
output_meta_info = []
output_finished_reason: List[BaseFinishReason] = []
@@ -1030,6 +1044,7 @@ class Scheduler:
)
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and (
self.running_batch is None or len(self.running_batch.reqs) == 0
):
@@ -1070,6 +1085,7 @@ class Scheduler:
break
def update_weights(self, recv_req: UpdateWeightReqInput):
"""In-place update of the weights."""
success, message = self.tp_worker.update_weights(recv_req)
if success:
flash_cache_success = self.flush_cache()