Simplify the interface of tp_worker (#1718)
This commit is contained in:
@@ -91,6 +91,7 @@ class Scheduler:
|
|||||||
port_args: PortArgs,
|
port_args: PortArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
|
dp_rank: Optional[int],
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
@@ -144,13 +145,24 @@ class Scheduler:
|
|||||||
|
|
||||||
# Launch a tensor parallel worker
|
# Launch a tensor parallel worker
|
||||||
self.tp_worker = TpModelWorker(
|
self.tp_worker = TpModelWorker(
|
||||||
|
server_args=server_args,
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
server_args=server_args,
|
dp_rank=dp_rank,
|
||||||
nccl_port=port_args.nccl_port,
|
nccl_port=port_args.nccl_port,
|
||||||
)
|
)
|
||||||
self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
|
|
||||||
self.device = self.tp_worker.device
|
# Init states for overlap schedule
|
||||||
|
if self.server_args.enable_overlap_schedule:
|
||||||
|
self.forward_batch_generation = (
|
||||||
|
self.tp_worker.forward_batch_generation_non_blocking
|
||||||
|
)
|
||||||
|
self.resolve_next_token_ids = (
|
||||||
|
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||||
|
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||||
|
|
||||||
# Get token and memory info from the model worker
|
# Get token and memory info from the model worker
|
||||||
(
|
(
|
||||||
@@ -159,11 +171,11 @@ class Scheduler:
|
|||||||
self.max_running_requests,
|
self.max_running_requests,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.random_seed,
|
self.random_seed,
|
||||||
|
self.device,
|
||||||
) = self.tp_worker.get_token_and_memory_info()
|
) = self.tp_worker.get_token_and_memory_info()
|
||||||
|
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
|
||||||
|
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||||
set_random_seed(self.random_seed)
|
set_random_seed(self.random_seed)
|
||||||
self.pad_input_ids_func = getattr(
|
|
||||||
self.tp_worker.model_runner.model, "pad_input_ids", None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Print debug info
|
# Print debug info
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -173,9 +185,8 @@ class Scheduler:
|
|||||||
f"context_len={self.model_config.context_len}"
|
f"context_len={self.model_config.context_len}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init cache
|
# Init memory pool and cache
|
||||||
self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
|
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
|
||||||
self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
|
|
||||||
|
|
||||||
if (
|
if (
|
||||||
server_args.chunked_prefill_size is not None
|
server_args.chunked_prefill_size is not None
|
||||||
@@ -253,20 +264,6 @@ class Scheduler:
|
|||||||
with_stack=True,
|
with_stack=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Init states for overlap schedule
|
|
||||||
if self.server_args.enable_overlap_schedule:
|
|
||||||
self.forward_batch_generation = (
|
|
||||||
self.tp_worker.forward_batch_generation_non_blocking
|
|
||||||
)
|
|
||||||
self.resolve_next_token_ids = (
|
|
||||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
|
||||||
)
|
|
||||||
self.cache_finished_req = self.tree_cache.cache_finished_req
|
|
||||||
else:
|
|
||||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
|
||||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
|
||||||
self.cache_finished_req = self.tree_cache.cache_finished_req
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def event_loop_normal(self):
|
def event_loop_normal(self):
|
||||||
self.last_batch = None
|
self.last_batch = None
|
||||||
@@ -779,7 +776,7 @@ class Scheduler:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
@@ -808,7 +805,7 @@ class Scheduler:
|
|||||||
req.check_finished()
|
req.check_finished()
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
else:
|
else:
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
|
||||||
@@ -845,7 +842,7 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if req.finished():
|
if req.finished():
|
||||||
self.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
|
|
||||||
if req.return_logprob:
|
if req.return_logprob:
|
||||||
req.output_token_logprobs.append(
|
req.output_token_logprobs.append(
|
||||||
@@ -1069,7 +1066,7 @@ class Scheduler:
|
|||||||
for req in self.running_batch.reqs:
|
for req in self.running_batch.reqs:
|
||||||
if req.rid == recv_req.rid and not req.finished():
|
if req.rid == recv_req.rid and not req.finished():
|
||||||
req.finished_reason = FINISH_ABORT()
|
req.finished_reason = FINISH_ABORT()
|
||||||
self.cache_finished_req(req)
|
self.tree_cache.cache_finished_req(req)
|
||||||
break
|
break
|
||||||
|
|
||||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||||
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
|
|||||||
suppress_other_loggers()
|
suppress_other_loggers()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
|
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||||
pipe_writer.send("ready")
|
pipe_writer.send("ready")
|
||||||
if server_args.enable_overlap_schedule:
|
if server_args.enable_overlap_schedule:
|
||||||
scheduler.event_loop_overlap()
|
scheduler.event_loop_overlap()
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import logging
|
|||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from queue import Queue
|
from queue import Queue
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@@ -40,9 +41,10 @@ class TpModelWorker:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
server_args: ServerArgs,
|
||||||
gpu_id: int,
|
gpu_id: int,
|
||||||
tp_rank: int,
|
tp_rank: int,
|
||||||
server_args: ServerArgs,
|
dp_rank: Optional[int],
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
@@ -116,6 +118,19 @@ class TpModelWorker:
|
|||||||
self.max_running_requests,
|
self.max_running_requests,
|
||||||
self.max_req_input_len,
|
self.max_req_input_len,
|
||||||
self.random_seed,
|
self.random_seed,
|
||||||
|
self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_pad_input_ids_func(self):
|
||||||
|
return getattr(self.model_runner.model, "pad_input_ids", None)
|
||||||
|
|
||||||
|
def get_tp_cpu_group(self):
|
||||||
|
return self.model_runner.tp_group.cpu_group
|
||||||
|
|
||||||
|
def get_memory_pool(self):
|
||||||
|
return (
|
||||||
|
self.model_runner.req_to_token_pool,
|
||||||
|
self.model_runner.token_to_kv_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_overlap_status(self):
|
def init_overlap_status(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user