Unify the memory pool api and tp worker API (#1724)
This commit is contained in:
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
|
||||
It will be transformed from CPU scheduler to GPU model runner.
|
||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
@@ -522,12 +524,12 @@ class ScheduleBatch:
|
||||
assert seq_len - pre_len == req.extend_input_len
|
||||
|
||||
if pre_len > 0:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
|
||||
req.prefix_indices
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
|
||||
)
|
||||
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
||||
out_cache_loc[pt : pt + req.extend_input_len]
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(pre_len, seq_len)),
|
||||
out_cache_loc[pt : pt + req.extend_input_len],
|
||||
)
|
||||
|
||||
# Compute the relative logprob_start_len in an extend batch
|
||||
@@ -765,9 +767,8 @@ class ScheduleBatch:
|
||||
# Alloc mem
|
||||
bs = len(self.reqs)
|
||||
self.out_cache_loc = self.alloc_token_slots(bs)
|
||||
|
||||
self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
|
||||
self.out_cache_loc
|
||||
self.req_to_token_pool.write(
|
||||
(self.req_pool_indices, self.seq_lens), self.out_cache_loc
|
||||
)
|
||||
self.seq_lens.add_(1)
|
||||
|
||||
@@ -848,7 +849,6 @@ class ScheduleBatch:
|
||||
extend_logprob_start_lens = self.extend_logprob_start_lens
|
||||
image_inputs = [r.image_inputs for r in self.reqs]
|
||||
|
||||
lora_paths = [req.lora_path for req in self.reqs]
|
||||
if self.has_regex:
|
||||
self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
|
||||
self.sampling_info.regex_fsm_states = [
|
||||
@@ -869,13 +869,14 @@ class ScheduleBatch:
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens,
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=extend_seq_lens,
|
||||
extend_prefix_lens=extend_prefix_lens,
|
||||
extend_logprob_start_lens=extend_logprob_start_lens,
|
||||
image_inputs=image_inputs,
|
||||
lora_paths=lora_paths,
|
||||
lora_paths=[req.lora_path for req in self.reqs],
|
||||
sampling_info=self.sampling_info,
|
||||
mrope_positions_delta=mrope_positions_delta,
|
||||
)
|
||||
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The memory pool operation records
|
||||
req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool
|
||||
top_logprobs_nums: Optional[List[int]]
|
||||
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
|
||||
req_pool_indices=self.req_pool_indices,
|
||||
seq_lens=self.seq_lens.clone(),
|
||||
out_cache_loc=self.out_cache_loc,
|
||||
req_to_token_pool_records=self.req_to_token_pool_records,
|
||||
return_logprob=self.return_logprob,
|
||||
top_logprobs_nums=self.top_logprobs_nums,
|
||||
extend_seq_lens=self.extend_seq_lens,
|
||||
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
|
||||
sampling_info=self.sampling_info.copy(),
|
||||
mrope_positions_delta=self.mrope_positions_delta,
|
||||
)
|
||||
|
||||
def to(self, device: str):
|
||||
self.input_ids = self.input_ids.to(device, non_blocking=True)
|
||||
self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
|
||||
self.seq_lens = self.seq_lens.to(device, non_blocking=True)
|
||||
self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
|
||||
self.req_to_token_pool_records = [
|
||||
(x, y.to(device, non_blocking=True))
|
||||
for x, y in self.req_to_token_pool_records
|
||||
]
|
||||
self.sampling_info.to(device)
|
||||
|
||||
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
|
||||
ImageInputs,
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.managers.schedule_policy import (
|
||||
AddReqResult,
|
||||
@@ -144,25 +145,27 @@ class Scheduler:
|
||||
)
|
||||
|
||||
# Launch a tensor parallel worker
|
||||
self.tp_worker = TpModelWorker(
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
TpWorkerClass = TpModelWorker
|
||||
else:
|
||||
TpWorkerClass = TpModelWorker
|
||||
self.tp_worker = TpWorkerClass(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
dp_rank=dp_rank,
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
|
||||
# Init states for overlap schedule
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
self.forward_batch_generation = (
|
||||
self.tp_worker.forward_batch_generation_non_blocking
|
||||
)
|
||||
self.resolve_next_token_ids = (
|
||||
lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
|
||||
)
|
||||
self.forward_batch_generation = (
|
||||
self.tp_worker.forward_batch_generation_non_blocking
|
||||
)
|
||||
else:
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
self.resolve_next_token_ids = lambda bid, x: x.tolist()
|
||||
self.forward_batch_generation = self.tp_worker.forward_batch_generation
|
||||
|
||||
# Get token and memory info from the model worker
|
||||
(
|
||||
@@ -172,9 +175,14 @@ class Scheduler:
|
||||
self.max_req_input_len,
|
||||
self.random_seed,
|
||||
self.device,
|
||||
) = self.tp_worker.get_token_and_memory_info()
|
||||
worker_global_server_args_dict,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
) = self.tp_worker.get_worker_info()
|
||||
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
|
||||
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||
global_server_args_dict.update(worker_global_server_args_dict)
|
||||
set_random_seed(self.random_seed)
|
||||
|
||||
# Print debug info
|
||||
@@ -266,6 +274,7 @@ class Scheduler:
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop_normal(self):
|
||||
"""A normal blocking scheduler loop."""
|
||||
self.last_batch = None
|
||||
|
||||
while True:
|
||||
@@ -296,6 +305,7 @@ class Scheduler:
|
||||
|
||||
@torch.inference_mode()
|
||||
def event_loop_overlap(self):
|
||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||
result_queue = deque()
|
||||
|
||||
self.last_batch = None
|
||||
@@ -572,6 +582,7 @@ class Scheduler:
|
||||
else set([])
|
||||
)
|
||||
|
||||
# Get requests from the waiting queue to a new prefill batch
|
||||
for req in self.waiting_queue:
|
||||
if (
|
||||
self.lora_paths
|
||||
@@ -673,6 +684,7 @@ class Scheduler:
|
||||
return new_batch
|
||||
|
||||
def update_running_batch(self):
|
||||
"""Update the current running decoding batch."""
|
||||
global test_retract
|
||||
batch = self.running_batch
|
||||
|
||||
@@ -712,6 +724,7 @@ class Scheduler:
|
||||
batch.prepare_for_decode()
|
||||
|
||||
def run_batch(self, batch: ScheduleBatch):
|
||||
"""Run a batch."""
|
||||
if self.is_generation:
|
||||
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
@@ -933,6 +946,7 @@ class Scheduler:
|
||||
return num_input_logprobs
|
||||
|
||||
def stream_output(self, reqs: List[Req]):
|
||||
"""Stream the output to detokenizer."""
|
||||
output_rids = []
|
||||
output_meta_info = []
|
||||
output_finished_reason: List[BaseFinishReason] = []
|
||||
@@ -1030,6 +1044,7 @@ class Scheduler:
|
||||
)
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and (
|
||||
self.running_batch is None or len(self.running_batch.reqs) == 0
|
||||
):
|
||||
@@ -1070,6 +1085,7 @@ class Scheduler:
|
||||
break
|
||||
|
||||
def update_weights(self, recv_req: UpdateWeightReqInput):
|
||||
"""In-place update of the weights."""
|
||||
success, message = self.tp_worker.update_weights(recv_req)
|
||||
if success:
|
||||
flash_cache_success = self.flush_cache()
|
||||
|
||||
@@ -27,7 +27,7 @@ import torch
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
@@ -111,7 +111,7 @@ class TpModelWorker:
|
||||
if server_args.enable_overlap_schedule:
|
||||
self.init_overlap_status()
|
||||
|
||||
def get_token_and_memory_info(self):
|
||||
def get_worker_info(self):
|
||||
return (
|
||||
self.max_total_num_tokens,
|
||||
self.max_prefill_tokens,
|
||||
@@ -119,6 +119,10 @@ class TpModelWorker:
|
||||
self.max_req_input_len,
|
||||
self.random_seed,
|
||||
self.device,
|
||||
global_server_args_dict,
|
||||
self.model_runner.req_to_token_pool.size,
|
||||
self.model_runner.req_to_token_pool.max_context_len,
|
||||
self.model_runner.token_to_kv_pool.size,
|
||||
)
|
||||
|
||||
def get_pad_input_ids_func(self):
|
||||
|
||||
@@ -56,6 +56,12 @@ class ReqToTokenPool:
|
||||
def clear(self):
|
||||
self.free_slots = list(range(self.size))
|
||||
|
||||
def write(self, indices, values):
|
||||
self.req_to_token[indices] = values
|
||||
|
||||
def get_write_records(self):
|
||||
return None
|
||||
|
||||
|
||||
class BaseTokenToKVPool:
|
||||
"""A memory pool that maps a token to its kv cache locations"""
|
||||
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
|
||||
):
|
||||
self.size = size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
if dtype == torch.float8_e5m2:
|
||||
# NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
|
||||
self.store_dtype = torch.uint8
|
||||
else:
|
||||
self.store_dtype = dtype
|
||||
self.device = device
|
||||
|
||||
self.free_slots = None
|
||||
self.is_not_in_free_group = True
|
||||
|
||||
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node = self.match_prefix(token_ids)
|
||||
assert len(new_indices) == len(token_ids)
|
||||
self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
|
||||
] = new_indices[len(req.prefix_indices) :]
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
)
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
|
||||
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
|
||||
- ScheduleBatch is managed by `scheduler.py::Scheduler`.
|
||||
It contains high-level scheduling data. Most of the data is on the CPU.
|
||||
- ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
|
||||
It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
|
||||
It will be transformed from CPU scheduler to GPU model runner.
|
||||
- ForwardBatch is managed by `model_runner.py::ModelRunner`.
|
||||
It contains low-level tensor data. Most of the data consists of GPU tensors.
|
||||
"""
|
||||
|
||||
@@ -131,6 +131,13 @@ class ModelRunner:
|
||||
]:
|
||||
server_args.disable_cuda_graph = True
|
||||
|
||||
if self.server_args.enable_overlap_schedule:
|
||||
logger.warning(
|
||||
"Overlap scheduler is enabled. This is an experimental feature. "
|
||||
"Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
|
||||
"and embedding APIs are not supported and will lead to wrong results."
|
||||
)
|
||||
|
||||
# Global vars
|
||||
if server_args.show_time_cost:
|
||||
enable_show_time_cost()
|
||||
|
||||
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
|
||||
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
||||
is_all_greedy=top_ks.max().item() <= 1,
|
||||
vocab_size=vocab_size,
|
||||
device=batch.input_ids.device,
|
||||
device=device,
|
||||
)
|
||||
# TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
|
||||
|
||||
@@ -224,3 +224,13 @@ class SamplingBatchInfo:
|
||||
vocab_size=self.vocab_size,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def to(self, device: str):
|
||||
for item in [
|
||||
"temperatures",
|
||||
"top_ps",
|
||||
"top_ks",
|
||||
"min_ps",
|
||||
]:
|
||||
value = getattr(self, item)
|
||||
setattr(self, item, value.to(device, non_blocking=True))
|
||||
|
||||
Reference in New Issue
Block a user