Unify the memory pool api and tp worker API (#1724)
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -144,25 +145,27 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )

         # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
         else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation

         # Get token and memory info from the model worker
         (
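Note: both branches currently pick the same TpModelWorker class, so the if/else mainly reserves a hook for a dedicated overlap worker. The pair of attributes bound afterward is what the rest of the scheduler relies on: every call site uses self.forward_batch_generation(...) and self.resolve_next_token_ids(...) without re-checking the overlap flag. A minimal sketch of that indirection, using hypothetical stand-in workers (only the two method names from the diff are modeled):

# Hypothetical stand-ins, not the real TpModelWorker API.
class BlockingWorker:
    def forward_batch_generation(self, batch):
        return [101, 102]  # real next-token ids, available immediately

class OverlapWorker:
    def __init__(self):
        self._futures = {}

    def forward_batch_generation_non_blocking(self, batch):
        # Hand back placeholder ids now; real ids are resolved later by bid.
        self._futures[batch["bid"]] = [101, 102]
        return [-1, -1]

    def resolve_future_token_ids(self, bid):
        return self._futures.pop(bid)

def bind_generation_hooks(worker, enable_overlap_schedule):
    # Mirrors the diff: choose the behavior once, at init time.
    if enable_overlap_schedule:
        forward = worker.forward_batch_generation_non_blocking
        resolve = lambda bid, x: worker.resolve_future_token_ids(bid)
    else:
        forward = worker.forward_batch_generation
        resolve = lambda bid, x: x  # the diff calls x.tolist() on a tensor
    return forward, resolve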
@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Print debug info
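get_worker_info() appears to fold the old get_token_and_memory_info() tuple and the worker's copy of the server-args dict into a single call; three trailing values are discarded here. Merging worker_global_server_args_dict back keeps the scheduler process in sync with settings the worker only finalizes at model-load time. A small sketch of the pattern with hypothetical values (the real tuple has more leading fields, elided above):

global_server_args_dict = {"attention_backend": None}

def get_worker_info():
    # Hypothetical payload: the worker owns settings resolved during
    # model loading and hands them back to the scheduler.
    worker_args = {"attention_backend": "flashinfer"}
    return 4096, 42, "cuda", worker_args, None, None, None

(max_req_input_len, random_seed, device,
 worker_global_server_args_dict, _, _, _) = get_worker_info()
global_server_args_dict.update(worker_global_server_args_dict)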
@@ -266,6 +274,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
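The new docstring is literal: in this loop, GPU compute and CPU post-processing strictly alternate, so each sits on the other's critical path. A simplified sketch of the blocking shape (assumed; not the verbatim body):

def event_loop_normal(get_next_batch, run_batch, process_batch_result):
    # Blocking loop: finish processing batch k before launching batch k+1.
    while True:
        batch = get_next_batch()
        if batch is not None:
            result = run_batch(batch)  # waits for real next-token ids
            process_batch_result(batch, result)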
@@ -296,6 +305,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
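result_queue is the heart of the overlap: batch k+1 is launched before batch k's results are consumed, so the CPU-side processing of batch k hides behind the GPU's work on batch k+1. A simplified sketch of that shape (assumed; the real loop also handles request intake and idle steps):

from collections import deque

def event_loop_overlap(get_next_batch, run_batch_async, process_batch_result):
    result_queue = deque()
    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch without waiting; keep the placeholder result for later.
            result_queue.append((batch, run_batch_async(batch)))
        # Keep one batch in flight; drain older results while the GPU runs.
        if len(result_queue) > 1 or (batch is None and result_queue):
            process_batch_result(*result_queue.popleft())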
@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -673,6 +684,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch

@@ -712,6 +724,7 @@ class Scheduler:
         batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
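With the hooks bound in __init__, run_batch stays overlap-agnostic. A sketch of the call shape the names here imply (assumed, including the return tuple; not the verbatim body):

def run_batch_sketch(scheduler, batch):
    model_worker_batch = batch.get_model_worker_batch()
    # One call site serves both modes via the __init__ indirection.
    logits_output, next_token_ids = scheduler.forward_batch_generation(
        model_worker_batch
    )
    # Under overlap scheduling these ids may be placeholders, swapped for
    # real ids later via scheduler.resolve_next_token_ids(bid, next_token_ids).
    return logits_output, next_token_ids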
@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1030,6 +1044,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1070,6 +1085,7 @@ class Scheduler:
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
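The flush after a successful update is about correctness, not just memory: KV-cache entries computed under the old weights must not be reused against the new ones. A sketch of the guard, assuming flush_cache() returns a success flag as the flash_cache_success name suggests:

def update_weights_sketch(scheduler, recv_req):
    # New weights make KV entries computed under the old weights stale.
    success, message = scheduler.tp_worker.update_weights(recv_req)
    if success:
        flush_ok = scheduler.flush_cache()
        assert flush_ok, "failed to flush the cache after a weight update"
    return success, message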