[PP] Add pipeline parallelism (#5724)

This commit is contained in:
Ying Sheng
2025-04-30 18:18:07 -07:00
committed by GitHub
parent e97e57e699
commit 11383cec3c
25 changed files with 1150 additions and 308 deletions

View File

@@ -181,44 +181,62 @@ class DataParallelController:
enable=server_args.enable_memory_saver
)
# Launch tensor parallel scheduler processes
scheduler_pipe_readers = []
tp_size_per_node = server_args.tp_size // server_args.nnodes
nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
tp_rank_range = range(
tp_size_per_node * server_args.node_rank,
tp_size_per_node * (server_args.node_rank + 1),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
)
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
pp_rank_range = range(
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
)
for pp_rank in pp_rank_range:
for tp_rank in tp_rank_range:
rank_port_args = port_args
if server_args.enable_dp_attention:
# dp attention has different sharding logic
_, _, dp_rank = compute_dp_attention_world_info(
server_args.enable_dp_attention,
tp_rank,
server_args.tp_size,
server_args.dp_size,
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism resues the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ ((pp_rank % pp_size_per_node) * tp_size_per_node)
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
# compute zmq ports for this dp rank
rank_port_args = PortArgs.init_new(server_args, dp_rank)
# Data parallelism resues the tensor parallelism group,
# so all dp ranks should use the same nccl port.
rank_port_args.nccl_port = port_args.nccl_port
reader, writer = mp.Pipe(duplex=False)
gpu_id = (
server_args.base_gpu_id
+ base_gpu_id
+ (tp_rank % tp_size_per_node) * server_args.gpu_id_step
)
proc = mp.Process(
target=run_scheduler_process,
args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
proc = mp.Process(
target=run_scheduler_process,
args=(
server_args,
rank_port_args,
gpu_id,
tp_rank,
pp_rank,
dp_rank,
writer,
),
)
with memory_saver_adapter.configure_subprocess():
proc.start()
self.scheduler_procs.append(proc)
scheduler_pipe_readers.append(reader)
# Wait for model to finish loading
scheduler_info = []

View File

@@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
global_server_args_dict = {
"attention_backend": ServerArgs.attention_backend,
"sampling_backend": ServerArgs.sampling_backend,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
"torchao_config": ServerArgs.torchao_config,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"deepep_mode": ServerArgs.deepep_mode,
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"enable_deepep_moe": ServerArgs.enable_deepep_moe,
"enable_dp_attention": ServerArgs.enable_dp_attention,
"enable_ep_moe": ServerArgs.enable_ep_moe,
"enable_nan_detection": ServerArgs.enable_nan_detection,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"max_micro_batch_size": ServerArgs.max_micro_batch_size,
"moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
"n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
"sampling_backend": ServerArgs.sampling_backend,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"torchao_config": ServerArgs.torchao_config,
"triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
}
logger = logging.getLogger(__name__)
@@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Events
launch_done: Optional[threading.Event] = None
# For chunked prefill in PP
chunked_req: Optional[Req] = None
# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
@@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# For extend and mixed chunekd prefill
prefix_lens: List[int] = None
extend_lens: List[int] = None
extend_num_tokens: int = None
extend_num_tokens: Optional[int] = None
decoding_reqs: List[Req] = None
extend_logprob_start_lens: List[int] = None
# It comes empty list if logprob is not required.
@@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
chunked_req: Optional[Req] = None,
):
return_logprob = any(req.return_logprob for req in reqs)
@@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=any(req.return_hidden_states for req in reqs),
chunked_req=chunked_req,
)
def batch_size(self):
@@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def retract_decode(self, server_args: ServerArgs):
"""Retract the decoding requests when there is not enough memory."""
sorted_indices = [i for i in range(len(self.reqs))]
sorted_indices = list(range(len(self.reqs)))
# TODO(lsyin): improve retraction policy for radix cache
# For spec decoding, filter_batch API can only filter
@@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
def filter_batch(
self,
chunked_req_to_exclude: Optional[Req] = None,
chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
keep_indices: Optional[List[int]] = None,
):
if keep_indices is None:
if isinstance(chunked_req_to_exclude, Req):
chunked_req_to_exclude = [chunked_req_to_exclude]
elif chunked_req_to_exclude is None:
chunked_req_to_exclude = []
keep_indices = [
i
for i in range(len(self.reqs))
if not self.reqs[i].finished()
and self.reqs[i] is not chunked_req_to_exclude
and not self.reqs[i] in chunked_req_to_exclude
]
if keep_indices is None or len(keep_indices) == 0:

View File

@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
ReqToMetadataIdxAllocator,
TransferBackend,
)
from sglang.srt.distributed import get_pp_group, get_world_group
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
get_bool_env_var,
get_zmq_socket,
kill_itself_when_parent_died,
point_to_point_pyobj,
pyspy_dump_schedulers,
set_gpu_proc_affinity,
set_random_seed,
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@dataclass
class GenerationBatchResult:
logits_output: LogitsProcessorOutput
next_token_ids: List[int]
logits_output: Optional[LogitsProcessorOutput]
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
next_token_ids: Optional[List[int]]
extend_input_len_per_req: List[int]
extend_logprob_start_len_per_req: List[int]
bid: int
@@ -171,12 +178,16 @@ class Scheduler(
port_args: PortArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
):
# Parse args
self.server_args = server_args
self.tp_rank = tp_rank
self.pp_rank = pp_rank
self.tp_size = server_args.tp_size
self.pp_size = server_args.pp_size
self.dp_size = server_args.dp_size
self.schedule_policy = server_args.schedule_policy
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
@@ -192,7 +203,6 @@ class Scheduler(
self.page_size = server_args.page_size
# Distributed rank info
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
@@ -204,7 +214,7 @@ class Scheduler(
# Init inter-process communication
context = zmq.Context(2)
if self.attn_tp_rank == 0:
if self.pp_rank == 0 and self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
)
@@ -259,6 +269,7 @@ class Scheduler(
server_args=server_args,
gpu_id=gpu_id,
tp_rank=tp_rank,
pp_rank=pp_rank,
dp_rank=dp_rank,
nccl_port=port_args.nccl_port,
)
@@ -292,8 +303,18 @@ class Scheduler(
_,
_,
) = self.tp_worker.get_worker_info()
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
if global_server_args_dict["max_micro_batch_size"] is None:
global_server_args_dict["max_micro_batch_size"] = max(
self.max_running_requests // server_args.pp_size, 1
)
self.tp_group = self.tp_worker.get_tp_group()
self.tp_cpu_group = self.tp_group.cpu_group
self.attn_tp_group = self.tp_worker.get_attention_tp_group()
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
self.pp_group = get_pp_group()
self.world_group = get_world_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
@@ -673,26 +694,141 @@ class Scheduler(
self.last_batch = batch
@DynamicGradMode()
def event_loop_pp(self):
"""A non-overlap scheduler loop for pipeline parallelism."""
mbs = [None] * self.pp_size
last_mbs = [None] * self.pp_size
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
mbs[mb_id] = self.get_next_batch_to_run()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
result = self.run_batch(self.cur_batch)
# send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
output_result = GenerationBatchResult(
logits_output=None,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
)
self.process_batch_result(mbs[next_mb_id], output_result)
last_mbs[next_mb_id] = mbs[next_mb_id]
# carry the outputs to the next stage
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if not self.pp_group.is_last_rank:
# send out reqs to the next stage
dp_offset = self.dp_rank * self.attn_tp_size
if self.attn_tp_rank == 0:
point_to_point_pyobj(
recv_reqs,
self.pp_rank * self.tp_size + dp_offset,
self.world_group.cpu_group,
self.pp_rank * self.tp_size + dp_offset,
(self.pp_rank + 1) * self.tp_size + dp_offset,
)
# send out proxy tensors to the next stage
if self.cur_batch:
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
all_gather_group=self.attn_tp_group,
)
pp_outputs = next_pp_outputs
# When the server is idle, self-check and re-init some states
if server_is_idle:
self.check_memory()
self.new_token_ratio = self.init_new_token_ratio
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if self.attn_tp_rank == 0:
recv_reqs = []
if self.pp_rank == 0:
if self.attn_tp_rank == 0:
recv_reqs = []
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_rpc)
while True:
try:
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_rpc)
else:
recv_reqs = None
else:
recv_reqs = None
if self.attn_tp_rank == 0:
dp_offset = self.dp_rank * self.attn_tp_size
recv_reqs = point_to_point_pyobj(
[],
self.pp_rank * self.tp_size + dp_offset,
self.world_group.cpu_group,
(self.pp_rank - 1) * self.tp_size + dp_offset,
self.pp_rank * self.tp_size + dp_offset,
)
else:
recv_reqs = None
if self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
@@ -715,20 +851,27 @@ class Scheduler(
control_reqs = None
if self.attn_tp_size != 1:
attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
work_reqs = broadcast_pyobj(
work_reqs,
self.attn_tp_rank,
self.attn_tp_group.rank,
self.attn_tp_cpu_group,
src=attn_tp_rank_0,
src=self.attn_tp_group.ranks[0],
)
if self.tp_size != 1:
control_reqs = broadcast_pyobj(
control_reqs, self.tp_rank, self.tp_cpu_group
control_reqs,
self.tp_group.rank,
self.tp_cpu_group,
src=self.tp_group.ranks[0],
)
recv_reqs = work_reqs + control_reqs
elif self.tp_size != 1:
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
recv_reqs = broadcast_pyobj(
recv_reqs,
self.tp_group.rank,
self.tp_cpu_group,
src=self.tp_group.ranks[0],
)
return recv_reqs
def process_input_requests(self, recv_reqs: List):
@@ -1026,12 +1169,14 @@ class Scheduler(
self.metrics_collector.log_stats(self.stats)
def log_decode_stats(self):
def log_decode_stats(self, running_batch=None):
batch = running_batch or self.running_batch
gap_latency = time.time() - self.last_decode_stats_tic
self.last_decode_stats_tic = time.time()
self.last_gen_throughput = self.num_generated_tokens / gap_latency
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs)
num_running_reqs = len(batch.reqs)
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
@@ -1131,19 +1276,25 @@ class Scheduler(
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
# Merge the prefill batch into the running batch
chunked_req_to_exclude = set()
if self.chunked_req:
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
chunked_req_to_exclude.add(self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
if self.last_batch and self.last_batch.forward_mode.is_extend():
if self.chunked_req:
# Move the chunked request out of the batch so that we can merge
# only finished requests to running_batch.
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
self.tree_cache.cache_unfinished_req(self.chunked_req)
# chunked request keeps its rid but will get a new req_pool_idx
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
self.running_batch.batch_is_full = False
if self.last_batch.chunked_req is not None:
# In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
# We need to discard it.
chunked_req_to_exclude.add(self.last_batch.chunked_req)
# Filter batch
last_bs = self.last_batch.batch_size()
self.last_batch.filter_batch()
self.last_batch.filter_batch(
chunked_req_to_exclude=list(chunked_req_to_exclude)
)
if self.last_batch.batch_size() < last_bs:
self.running_batch.batch_is_full = False
@@ -1173,6 +1324,12 @@ class Scheduler(
return ret
def get_num_allocatable_reqs(self, running_bs):
res = global_server_args_dict["max_micro_batch_size"] - running_bs
if self.pp_size > 1:
res = min(res, self.req_to_token_pool.available_size())
return res
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
# Check if the grammar is ready in the grammar queue
if self.grammar_queue:
@@ -1185,7 +1342,12 @@ class Scheduler(
return None
running_bs = len(self.running_batch.reqs)
if running_bs >= self.max_running_requests:
# Igore the check if self.chunked_req is not None.
# In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
# as the space for the chunked request has just been released.
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
self.running_batch.batch_is_full = True
return None
@@ -1229,7 +1391,7 @@ class Scheduler(
self.running_batch.batch_is_full = True
break
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
self.running_batch.batch_is_full = True
break
@@ -1241,16 +1403,14 @@ class Scheduler(
res = adder.add_one_req(
req, self.chunked_req, self.enable_hierarchical_cache
)
if res != AddReqResult.CONTINUE:
if res == AddReqResult.NO_TOKEN:
if self.enable_hierarchical_cache:
# Set batch_is_full after making sure there are requests that can be served
self.running_batch.batch_is_full = len(
adder.can_run_list
) > 0 or (
self.running_batch is not None
and not self.running_batch.is_empty()
)
) > 0 or (not self.running_batch.is_empty())
else:
self.running_batch.batch_is_full = True
break
@@ -1293,6 +1453,7 @@ class Scheduler(
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
chunked_req=self.chunked_req,
)
new_batch.prepare_for_extend()
@@ -1370,9 +1531,14 @@ class Scheduler(
if self.is_generation:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
if self.pp_group.is_last_rank:
logits_output, next_token_ids = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
else:
pp_hidden_states_proxy_tensors, _ = (
self.tp_worker.forward_batch_generation(model_worker_batch)
)
bid = model_worker_batch.bid
else:
(
@@ -1386,7 +1552,9 @@ class Scheduler(
)
self.spec_num_total_forward_ct += batch.batch_size()
self.num_generated_tokens += num_accepted_tokens
batch.output_ids = next_token_ids
if self.pp_group.is_last_rank:
batch.output_ids = next_token_ids
# These 2 values are needed for processing the output, but the values can be
# modified by overlap schedule. So we have to copy them here so that
@@ -1401,8 +1569,13 @@ class Scheduler(
extend_logprob_start_len_per_req = None
ret = GenerationBatchResult(
logits_output=logits_output,
next_token_ids=next_token_ids,
logits_output=logits_output if self.pp_group.is_last_rank else None,
pp_hidden_states_proxy_tensors=(
pp_hidden_states_proxy_tensors
if not self.pp_group.is_last_rank
else None
),
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=bid,
@@ -1553,6 +1726,7 @@ class Scheduler(
def move_ready_grammar_requests(self):
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
num_ready_reqs = 0
for req in self.grammar_queue:
try:
@@ -1619,7 +1793,11 @@ class Scheduler(
def flush_cache(self):
"""Flush the memory pool and cache."""
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
if (
len(self.waiting_queue) == 0
and self.running_batch.is_empty()
and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
):
self.cur_batch = None
self.last_batch = None
self.tree_cache.reset()
@@ -1657,7 +1835,6 @@ class Scheduler(
ret["avg_spec_accept_length"] = (
self.cum_spec_accept_length / self.cum_spec_accept_count
)
if RECORD_STEP_TIME:
ret["step_time_dict"] = self.step_time_dict
return GetInternalStateReqOutput(
@@ -1668,6 +1845,7 @@ class Scheduler(
server_args_dict = recv_req.server_args
args_allow_update = set(
[
"max_micro_batch_size",
"speculative_accept_threshold_single",
"speculative_accept_threshold_acc",
]
@@ -1678,6 +1856,14 @@ class Scheduler(
logging.warning(f"Updating {k} is not supported.")
if_success = False
break
elif k == "max_micro_batch_size" and (
v > self.max_running_requests // self.pp_size or v < 1
):
logging.warning(
f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
)
if_success = False
break
if if_success:
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
avg_spec_accept_length = (
@@ -1959,6 +2145,16 @@ class Scheduler(
else:
del self.sessions[session_id]
def get_print_prefix(self):
prefix = ""
if self.dp_rank is not None:
prefix += f" DP{self.dp_rank}"
if self.server_args.tp_size > 1:
prefix += f" TP{self.tp_rank}"
if self.pp_size > 1:
prefix += f" PP{self.pp_rank}"
return prefix
def is_health_check_generate_req(recv_req):
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
@@ -1983,14 +2179,18 @@ def run_scheduler_process(
port_args: PortArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
):
# Generate the prefix
if dp_rank is None:
prefix = f" TP{tp_rank}"
else:
prefix = f" DP{dp_rank} TP{tp_rank}"
prefix = ""
if dp_rank is not None:
prefix += f" DP{dp_rank}"
if server_args.tp_size > 1:
prefix += f" TP{tp_rank}"
if server_args.pp_size > 1:
prefix += f" PP{pp_rank}"
# Config the process
kill_itself_when_parent_died()
@@ -2012,7 +2212,7 @@ def run_scheduler_process(
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
pipe_writer.send(
{
"status": "ready",
@@ -2023,7 +2223,9 @@ def run_scheduler_process(
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
if disaggregation_mode == DisaggregationMode.NULL:
if scheduler.enable_overlap:
if server_args.pp_size > 1:
scheduler.event_loop_pp()
elif scheduler.enable_overlap:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
@@ -2032,6 +2234,7 @@ def run_scheduler_process(
scheduler.event_loop_overlap_disagg_prefill()
else:
scheduler.event_loop_normal_disagg_prefill()
elif disaggregation_mode == DisaggregationMode.DECODE:
if scheduler.enable_overlap:
scheduler.event_loop_overlap_disagg_decode()

View File

@@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin:
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats()
self.log_decode_stats(running_batch=batch)
def add_input_logprob_return_values(
self: Scheduler,

View File

@@ -15,11 +15,12 @@
import logging
import threading
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
@@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
@@ -47,6 +48,7 @@ class TpModelWorker:
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
is_draft_worker: bool = False,
@@ -54,7 +56,9 @@ class TpModelWorker:
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_size = server_args.tp_size
self.tp_rank = tp_rank
self.pp_rank = pp_rank
# Init model and tokenizer
self.model_config = ModelConfig(
@@ -73,12 +77,15 @@ class TpModelWorker:
quantization=server_args.quantization,
is_draft_model=is_draft_worker,
)
self.model_runner = ModelRunner(
model_config=self.model_config,
mem_fraction_static=server_args.mem_fraction_static,
gpu_id=gpu_id,
tp_rank=tp_rank,
tp_size=server_args.tp_size,
pp_rank=pp_rank,
pp_size=server_args.pp_size,
nccl_port=nccl_port,
server_args=server_args,
is_draft_worker=is_draft_worker,
@@ -105,6 +112,10 @@ class TpModelWorker:
)
self.device = self.model_runner.device
# Init nccl groups
self.pp_group = get_pp_group()
self.world_group = get_world_group()
# Profile number of tokens
self.max_total_num_tokens = self.model_runner.max_total_num_tokens
self.max_prefill_tokens = server_args.max_prefill_tokens
@@ -130,8 +141,9 @@ class TpModelWorker:
# Sync random seed across TP workers
self.random_seed = broadcast_pyobj(
[server_args.random_seed],
self.tp_rank,
self.model_runner.tp_group.cpu_group,
self.tp_size * self.pp_rank + tp_rank,
self.world_group.cpu_group,
src=self.world_group.ranks[0],
)[0]
set_random_seed(self.random_seed)
@@ -156,11 +168,14 @@ class TpModelWorker:
def get_pad_input_ids_func(self):
return getattr(self.model_runner.model, "pad_input_ids", None)
def get_tp_cpu_group(self):
return self.model_runner.tp_group.cpu_group
def get_tp_group(self):
return self.model_runner.tp_group
def get_attention_tp_group(self):
return self.model_runner.attention_tp_group
def get_attention_tp_cpu_group(self):
return self.model_runner.attention_tp_group.cpu_group
return getattr(self.model_runner.attention_tp_group, "cpu_group", None)
def get_memory_pool(self):
return (
@@ -172,19 +187,38 @@ class TpModelWorker:
self,
model_worker_batch: ModelWorkerBatch,
skip_sample: bool = False,
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()
pp_proxy_tensors = None
if not self.pp_group.is_first_rank:
pp_proxy_tensors = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.get_attention_tp_group()
)
)
if skip_sample:
next_token_ids = None
if self.pp_group.is_last_rank:
logits_output = self.model_runner.forward(
forward_batch, pp_proxy_tensors=pp_proxy_tensors
)
if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()
if skip_sample:
next_token_ids = None
else:
next_token_ids = self.model_runner.sample(
logits_output, model_worker_batch
)
return logits_output, next_token_ids
else:
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
return logits_output, next_token_ids
pp_proxy_tensors = self.model_runner.forward(
forward_batch,
pp_proxy_tensors=pp_proxy_tensors,
)
return pp_proxy_tensors.tensors, None
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)

View File

@@ -56,11 +56,14 @@ class TpModelWorkerClient:
server_args: ServerArgs,
gpu_id: int,
tp_rank: int,
pp_rank: int,
dp_rank: Optional[int],
nccl_port: int,
):
# Load the model
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
self.worker = TpModelWorker(
server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port
)
self.max_running_requests = self.worker.max_running_requests
self.device = self.worker.device
self.gpu_id = gpu_id
@@ -91,8 +94,11 @@ class TpModelWorkerClient:
def get_pad_input_ids_func(self):
return self.worker.get_pad_input_ids_func()
def get_tp_cpu_group(self):
return self.worker.get_tp_cpu_group()
def get_tp_group(self):
return self.worker.get_tp_group()
def get_attention_tp_group(self):
return self.worker.get_attention_tp_group()
def get_attention_tp_cpu_group(self):
return self.worker.get_attention_tp_cpu_group()