Unify the memory pool API and TP worker API (#1724)
@@ -23,6 +23,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
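
For orientation, here is a minimal, self-contained sketch of the three-level handoff this docstring describes. All class and field names below are illustrative stand-ins, not the actual sglang types:

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class ToyScheduleBatch:
    # CPU-side scheduling view: one entry per request.
    input_ids: List[List[int]]

    def get_model_worker_batch(self) -> "ToyModelWorkerBatch":
        # Keep only what the forward pass needs, still on the CPU.
        flat = [tok for req in self.input_ids for tok in req]
        return ToyModelWorkerBatch(input_ids=torch.tensor(flat, dtype=torch.int64))


@dataclass
class ToyModelWorkerBatch:
    input_ids: torch.Tensor  # CPU tensor handed to the TP worker

    def to_forward_batch(self, device: str) -> "ToyForwardBatch":
        # Device-side view consumed by the model runner.
        return ToyForwardBatch(input_ids=self.input_ids.to(device))


@dataclass
class ToyForwardBatch:
    input_ids: torch.Tensor


batch = ToyScheduleBatch(input_ids=[[1, 2, 3], [4, 5]])
worker_batch = batch.get_model_worker_batch()
forward_batch = worker_batch.to_forward_batch("cpu")  # "cuda" on a GPU machine
print(forward_batch.input_ids)
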
@@ -522,12 +524,12 @@ class ScheduleBatch:
             assert seq_len - pre_len == req.extend_input_len

             if pre_len > 0:
-                self.req_to_token_pool.req_to_token[req.req_pool_idx, :pre_len] = (
-                    req.prefix_indices
+                self.req_to_token_pool.write(
+                    (req.req_pool_idx, slice(0, pre_len)), req.prefix_indices
                 )
-            self.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
-                out_cache_loc[pt : pt + req.extend_input_len]
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(pre_len, seq_len)),
+                out_cache_loc[pt : pt + req.extend_input_len],
             )

             # Compute the relative logprob_start_len in an extend batch
@@ -765,9 +767,8 @@ class ScheduleBatch:
         # Alloc mem
         bs = len(self.reqs)
         self.out_cache_loc = self.alloc_token_slots(bs)
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
+        self.req_to_token_pool.write(
+            (self.req_pool_indices, self.seq_lens), self.out_cache_loc
         )
         self.seq_lens.add_(1)

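
The `write(...)` calls above pass a single index object (a row selector paired with a slice or a position tensor) together with the values, which is equivalent to the advanced indexing they replace. A quick self-contained check of that equivalence with toy sizes (not the real pool):

import torch

req_to_token = torch.zeros((4, 8), dtype=torch.int64)

# Extend-style write: old direct indexing vs. the (indices, values) form.
req_to_token[2, :3] = torch.tensor([10, 11, 12])
req_to_token[(2, slice(0, 3))] = torch.tensor([10, 11, 12])  # same assignment

# Decode-style write: one new token slot per request at its current length.
req_pool_indices = torch.tensor([0, 2])
seq_lens = torch.tensor([3, 3])
req_to_token[(req_pool_indices, seq_lens)] = torch.tensor([20, 21])
print(req_to_token)
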
@@ -848,7 +849,6 @@ class ScheduleBatch:
             extend_logprob_start_lens = self.extend_logprob_start_lens
         image_inputs = [r.image_inputs for r in self.reqs]

-        lora_paths = [req.lora_path for req in self.reqs]
         if self.has_regex:
             self.sampling_info.regex_fsms = [req.regex_fsm for req in self.reqs]
             self.sampling_info.regex_fsm_states = [
@@ -869,13 +869,14 @@ class ScheduleBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool.get_write_records(),
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=extend_seq_lens,
             extend_prefix_lens=extend_prefix_lens,
             extend_logprob_start_lens=extend_logprob_start_lens,
             image_inputs=image_inputs,
-            lora_paths=lora_paths,
+            lora_paths=[req.lora_path for req in self.reqs],
             sampling_info=self.sampling_info,
             mrope_positions_delta=mrope_positions_delta,
         )
@@ -911,6 +912,9 @@ class ModelWorkerBatch:
     # The indices of output tokens in the token_to_kv_pool
     out_cache_loc: torch.Tensor

+    # The memory pool operation records
+    req_to_token_pool_records: Optional[List[Tuple[Tuple, torch.Tensor]]]
+
     # For logprob
     return_logprob: bool
     top_logprobs_nums: Optional[List[int]]
@@ -940,6 +944,7 @@ class ModelWorkerBatch:
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens.clone(),
             out_cache_loc=self.out_cache_loc,
+            req_to_token_pool_records=self.req_to_token_pool_records,
             return_logprob=self.return_logprob,
             top_logprobs_nums=self.top_logprobs_nums,
             extend_seq_lens=self.extend_seq_lens,
@@ -950,3 +955,14 @@ class ModelWorkerBatch:
             sampling_info=self.sampling_info.copy(),
             mrope_positions_delta=self.mrope_positions_delta,
         )
+
+    def to(self, device: str):
+        self.input_ids = self.input_ids.to(device, non_blocking=True)
+        self.req_pool_indices = self.req_pool_indices.to(device, non_blocking=True)
+        self.seq_lens = self.seq_lens.to(device, non_blocking=True)
+        self.out_cache_loc = self.out_cache_loc.to(device, non_blocking=True)
+        self.req_to_token_pool_records = [
+            (x, y.to(device, non_blocking=True))
+            for x, y in self.req_to_token_pool_records
+        ]
+        self.sampling_info.to(device)
@@ -51,6 +51,7 @@ from sglang.srt.managers.schedule_batch import (
     ImageInputs,
     Req,
     ScheduleBatch,
+    global_server_args_dict,
 )
 from sglang.srt.managers.schedule_policy import (
     AddReqResult,
@@ -144,25 +145,27 @@ class Scheduler:
         )

         # Launch a tensor parallel worker
-        self.tp_worker = TpModelWorker(
+        if self.server_args.enable_overlap_schedule:
+            TpWorkerClass = TpModelWorker
+        else:
+            TpWorkerClass = TpModelWorker
+        self.tp_worker = TpWorkerClass(
             server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
             dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )

-        # Init states for overlap schedule
         if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
             self.resolve_next_token_ids = (
                 lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
             )
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
         else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
             self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation

         # Get token and memory info from the model worker
         (
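
Note that both branches above currently bind `TpWorkerClass` to `TpModelWorker`; the split presumably leaves room for a dedicated overlap-schedule worker later. Either way, the scheduler ends up with two callables: a forward entry point and a resolver for next-token ids. A self-contained sketch of that dispatch pattern (toy worker, illustrative only):

class ToyWorker:
    # Illustrative stand-in for the two TpModelWorker entry points.
    def forward_batch_generation(self, model_worker_batch):
        # Blocking path: returns concrete next-token ids.
        return [7, 8, 9]

    def forward_batch_generation_non_blocking(self, model_worker_batch):
        # Overlap path: returns a placeholder batch id to resolve later.
        return 0


class ToyScheduler:
    def __init__(self, worker, enable_overlap_schedule: bool):
        self.tp_worker = worker
        if enable_overlap_schedule:
            self.forward_batch_generation = (
                worker.forward_batch_generation_non_blocking
            )
            # Toy resolution of a placeholder batch id into token ids.
            self.resolve_next_token_ids = lambda bid, x: [7, 8, 9]
        else:
            self.forward_batch_generation = worker.forward_batch_generation
            self.resolve_next_token_ids = lambda bid, x: list(x)


scheduler = ToyScheduler(ToyWorker(), enable_overlap_schedule=False)
next_token_ids = scheduler.forward_batch_generation(model_worker_batch=None)
print(scheduler.resolve_next_token_ids(0, next_token_ids))
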
@@ -172,9 +175,14 @@ class Scheduler:
             self.max_req_input_len,
             self.random_seed,
             self.device,
-        ) = self.tp_worker.get_token_and_memory_info()
+            worker_global_server_args_dict,
+            _,
+            _,
+            _,
+        ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
+        global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)

         # Print debug info
@@ -266,6 +274,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_normal(self):
+        """A normal blocking scheduler loop."""
         self.last_batch = None

         while True:
@@ -296,6 +305,7 @@ class Scheduler:

     @torch.inference_mode()
     def event_loop_overlap(self):
+        """A scheduler loop that overlaps the CPU processing and GPU computation."""
         result_queue = deque()

         self.last_batch = None
@@ -572,6 +582,7 @@ class Scheduler:
             else set([])
         )

+        # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
             if (
                 self.lora_paths
@@ -673,6 +684,7 @@ class Scheduler:
         return new_batch

     def update_running_batch(self):
+        """Update the current running decoding batch."""
         global test_retract
         batch = self.running_batch

@@ -712,6 +724,7 @@ class Scheduler:
         batch.prepare_for_decode()

     def run_batch(self, batch: ScheduleBatch):
+        """Run a batch."""
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
@@ -933,6 +946,7 @@ class Scheduler:
         return num_input_logprobs

     def stream_output(self, reqs: List[Req]):
+        """Stream the output to detokenizer."""
         output_rids = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
@@ -1030,6 +1044,7 @@ class Scheduler:
         )

     def flush_cache(self):
+        """Flush the memory pool and cache."""
         if len(self.waiting_queue) == 0 and (
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
@@ -1070,6 +1085,7 @@ class Scheduler:
                 break

     def update_weights(self, recv_req: UpdateWeightReqInput):
+        """In-place update of the weights."""
         success, message = self.tp_worker.update_weights(recv_req)
         if success:
             flash_cache_success = self.flush_cache()
@@ -27,7 +27,7 @@ import torch
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
-from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -111,7 +111,7 @@ class TpModelWorker:
         if server_args.enable_overlap_schedule:
             self.init_overlap_status()

-    def get_token_and_memory_info(self):
+    def get_worker_info(self):
         return (
             self.max_total_num_tokens,
             self.max_prefill_tokens,
@@ -119,6 +119,10 @@ class TpModelWorker:
             self.max_req_input_len,
             self.random_seed,
             self.device,
+            global_server_args_dict,
+            self.model_runner.req_to_token_pool.size,
+            self.model_runner.req_to_token_pool.max_context_len,
+            self.model_runner.token_to_kv_pool.size,
         )

     def get_pad_input_ids_func(self):
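
`get_token_and_memory_info` becomes `get_worker_info`, and the returned tuple now also carries the worker's `global_server_args_dict` and the memory-pool sizes, so the scheduler can mirror the worker's configuration (see the scheduler hunk above). A toy sketch of that unpack-and-sync pattern; field names beyond those visible in the diff, and all values, are made up:

def get_worker_info():
    # Illustrative values only; the real tuple has more leading fields.
    device = "cuda"
    worker_global_server_args_dict = {"disable_flashinfer": False}
    req_to_token_pool_size = 128
    max_context_len = 2048
    token_to_kv_pool_size = 65536
    return (
        device,
        worker_global_server_args_dict,
        req_to_token_pool_size,
        max_context_len,
        token_to_kv_pool_size,
    )


# Scheduler side: unpack what it needs, ignore the rest, and sync the shared dict.
global_server_args_dict = {}
device, worker_args, _, _, _ = get_worker_info()
global_server_args_dict.update(worker_args)
print(device, global_server_args_dict)
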
@@ -56,6 +56,12 @@ class ReqToTokenPool:
     def clear(self):
         self.free_slots = list(range(self.size))

+    def write(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def get_write_records(self):
+        return None
+

 class BaseTokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""
@@ -68,12 +74,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
+        self.device = device

         self.free_slots = None
         self.is_not_in_free_group = True
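
The new `write`/`get_write_records` pair gives every call site one entry point into the request-to-token table; returning `None` here means the base pool applies writes eagerly and has nothing to replay. The `req_to_token_pool_records` field on `ModelWorkerBatch` (and its handling in `ModelWorkerBatch.to()`) suggests a pool variant for the overlap schedule that buffers writes instead. A hypothetical sketch of such a recording pool, not code from this PR:

import torch


class ToyRecordingReqToTokenPool:
    """Hypothetical variant: buffer (indices, values) writes for later replay."""

    def __init__(self, size: int, max_context_len: int):
        self.req_to_token = torch.zeros((size, max_context_len), dtype=torch.int32)
        self.write_records = []

    def write(self, indices, values):
        # Record instead of applying immediately.
        self.write_records.append((indices, values))

    def get_write_records(self):
        records = self.write_records
        self.write_records = []
        return records

    def apply_write_records(self, records):
        for indices, values in records:
            self.req_to_token[indices] = values


pool = ToyRecordingReqToTokenPool(size=4, max_context_len=8)
pool.write((1, slice(0, 3)), torch.tensor([10, 11, 12], dtype=torch.int32))
records = pool.get_write_records()

# Worker side: move the value tensors (as ModelWorkerBatch.to() does) and replay.
records = [(idx, val.to("cpu", non_blocking=True)) for idx, val in records]
pool.apply_write_records(records)
print(pool.req_to_token[1, :3])
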
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
         # The prefix indices could be updated, reuse it
         new_indices, new_last_node = self.match_prefix(token_ids)
         assert len(new_indices) == len(token_ids)
-        self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
-        ] = new_indices[len(req.prefix_indices) :]
+        self.req_to_token_pool.write(
+            (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+            new_indices[len(req.prefix_indices) :],
+        )

         self.dec_lock_ref(req.last_node)
         self.inc_lock_ref(new_last_node)
@@ -25,6 +25,8 @@ ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
 - ScheduleBatch is managed by `scheduler.py::Scheduler`.
   It contains high-level scheduling data. Most of the data is on the CPU.
 - ModelWorkerBatch is managed by `tp_worker.py::TpModelWorker`.
+  It is a subset of `ScheduleBatch` that only contains data related to the model forward on GPU.
+  It will be transformed from CPU scheduler to GPU model runner.
 - ForwardBatch is managed by `model_runner.py::ModelRunner`.
   It contains low-level tensor data. Most of the data consists of GPU tensors.
 """
@@ -131,6 +131,13 @@ class ModelRunner:
         ]:
             server_args.disable_cuda_graph = True

+        if self.server_args.enable_overlap_schedule:
+            logger.warning(
+                "Overlap scheduler is enabled. This is an experimental feature. "
+                "Sampling penalizer (e.g., frequency and repetition penalty), constrained decoding (e.g., regex, JSON), "
+                "and embedding APIs are not supported and will lead to wrong results."
+            )
+
         # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
@@ -78,7 +78,7 @@ class SamplingBatchInfo:
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
-            device=batch.input_ids.device,
+            device=device,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

@@ -224,3 +224,13 @@ class SamplingBatchInfo:
             vocab_size=self.vocab_size,
             device=self.device,
         )
+
+    def to(self, device: str):
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            value = getattr(self, item)
+            setattr(self, item, value.to(device, non_blocking=True))