Simplify the interface of tp_worker (#1718)
This commit is contained in:
@@ -91,6 +91,7 @@ class Scheduler:
|
||||
port_args: PortArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
):
|
||||
# Parse args
|
||||
self.server_args = server_args
|
||||
@@ -144,13 +145,24 @@ class Scheduler:
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
self.tp_worker = TpModelWorker(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
server_args=server_args,
|
||||
dp_rank=dp_rank,
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
||||
self.device = self.tp_worker.device
|
||||
|
||||
# Init states for overlap schedule
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
self.forward_batch_generation = (
|
||||
self.tp_worker.forward_batch_generation_non_blocking
|
||||
)
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
else:
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
@@ -159,11 +171,11 @@ class Scheduler:
|
||||
self.max_running_requests,
|
||||
self.max_req_input_len,
|
||||
self.random_seed,
|
||||
self.device,
|
||||
) = self.tp_worker.get_token_and_memory_info()
|
||||
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
|
||||
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||
set_random_seed(self.random_seed)
|
||||
self.pad_input_ids_func = getattr(
|
||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
||||
)
|
||||
|
||||
# Print debug info
|
||||
logger.info(
|
||||
@@ -173,9 +185,8 @@ class Scheduler:
|
||||
f"context_len={self.model_config.context_len}"
|
||||
)
|
||||
|
||||
# Init cache
|
||||
self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
|
||||
self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
|
||||
# Init memory pool and cache
|
||||
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
|
||||
|
||||
if (
|
||||
server_args.chunked_prefill_size is not None
|
||||
@@ -253,20 +264,6 @@ class Scheduler:
|
||||
with_stack=True,
|
||||
)
|
||||
|
||||
# Init states for overlap schedule
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
self.forward_batch_generation = (
|
||||
self.tp_worker.forward_batch_generation_non_blocking
|
||||
)
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
self.cache_finished_req = self.tree_cache.cache_finished_req
|
||||
else:
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
self.cache_finished_req = self.tree_cache.cache_finished_req
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop_normal(self):
|
||||
self.last_batch = None
|
||||
@@ -779,7 +776,7 @@ class Scheduler:
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
self.cache_finished_req(req)
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
@@ -808,7 +805,7 @@ class Scheduler:
|
||||
req.check_finished()
|
||||
|
||||
if req.finished():
|
||||
self.cache_finished_req(req)
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
else:
|
||||
self.tree_cache.cache_unfinished_req(req)
|
||||
|
||||
@@ -845,7 +842,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
if req.finished():
|
||||
self.cache_finished_req(req)
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
|
||||
if req.return_logprob:
|
||||
req.output_token_logprobs.append(
|
||||
@@ -1069,7 +1066,7 @@ class Scheduler:
|
||||
for req in self.running_batch.reqs:
|
||||
if req.rid == recv_req.rid and not req.finished():
|
||||
req.finished_reason = FINISH_ABORT()
|
||||
self.cache_finished_req(req)
|
||||
self.tree_cache.cache_finished_req(req)
|
||||
break
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
|
||||
suppress_other_loggers()
|
||||
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||
pipe_writer.send("ready")
|
||||
if server_args.enable_overlap_schedule:
|
||||
scheduler.event_loop_overlap()
|
||||
|
||||
@@ -20,6 +20,7 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
from queue import Queue
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -40,9 +41,10 @@ class TpModelWorker:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
server_args: ServerArgs,
|
||||
dp_rank: Optional[int],
|
||||
nccl_port: int,
|
||||
):
|
||||
# Parse args
|
||||
@@ -116,6 +118,19 @@ class TpModelWorker:
|
||||
self.max_running_requests,
|
||||
self.max_req_input_len,
|
||||
self.random_seed,
|
||||
self.device,
|
||||
)
|
||||
|
||||
def get_pad_input_ids_func(self):
|
||||
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||
|
||||
def get_tp_cpu_group(self):
|
||||
return self.model_runner.tp_group.cpu_group
|
||||
|
||||
def get_memory_pool(self):
|
||||
return (
|
||||
self.model_runner.req_to_token_pool,
|
||||
self.model_runner.token_to_kv_pool,
|
||||
)
|
||||
|
||||
def init_overlap_status(self):
|
||||
|
||||
Reference in New Issue
Block a user