[PP] Add pipeline parallelism (#5724)
This commit is contained in:
@@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import (
|
||||
ReqToMetadataIdxAllocator,
|
||||
TransferBackend,
|
||||
)
|
||||
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
@@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
PPProxyTensors,
|
||||
)
|
||||
from sglang.srt.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
@@ -127,6 +132,7 @@ from sglang.srt.utils import (
|
||||
get_bool_env_var,
|
||||
get_zmq_socket,
|
||||
kill_itself_when_parent_died,
|
||||
point_to_point_pyobj,
|
||||
pyspy_dump_schedulers,
|
||||
set_gpu_proc_affinity,
|
||||
set_random_seed,
|
||||
@@ -145,8 +151,9 @@ RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
||||
|
||||
@dataclass
|
||||
class GenerationBatchResult:
|
||||
logits_output: LogitsProcessorOutput
|
||||
next_token_ids: List[int]
|
||||
logits_output: Optional[LogitsProcessorOutput]
|
||||
pp_hidden_states_proxy_tensors: Optional[torch.Tensor]
|
||||
next_token_ids: Optional[List[int]]
|
||||
extend_input_len_per_req: List[int]
|
||||
extend_logprob_start_len_per_req: List[int]
|
||||
bid: int
|
||||
@@ -171,12 +178,16 @@ class Scheduler(
|
||||
port_args: PortArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
pp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
):
|
||||
# Parse args
|
||||
self.server_args = server_args
|
||||
self.tp_rank = tp_rank
|
||||
self.pp_rank = pp_rank
|
||||
self.tp_size = server_args.tp_size
|
||||
self.pp_size = server_args.pp_size
|
||||
self.dp_size = server_args.dp_size
|
||||
self.schedule_policy = server_args.schedule_policy
|
||||
self.lora_paths = server_args.lora_paths
|
||||
self.max_loras_per_batch = server_args.max_loras_per_batch
|
||||
@@ -192,7 +203,6 @@ class Scheduler(
|
||||
self.page_size = server_args.page_size
|
||||
|
||||
# Distributed rank info
|
||||
self.dp_size = server_args.dp_size
|
||||
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
|
||||
compute_dp_attention_world_info(
|
||||
server_args.enable_dp_attention,
|
||||
@@ -204,7 +214,7 @@ class Scheduler(
|
||||
|
||||
# Init inter-process communication
|
||||
context = zmq.Context(2)
|
||||
if self.attn_tp_rank == 0:
|
||||
if self.pp_rank == 0 and self.attn_tp_rank == 0:
|
||||
self.recv_from_tokenizer = get_zmq_socket(
|
||||
context, zmq.PULL, port_args.scheduler_input_ipc_name, False
|
||||
)
|
||||
@@ -259,6 +269,7 @@ class Scheduler(
|
||||
server_args=server_args,
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
pp_rank=pp_rank,
|
||||
dp_rank=dp_rank,
|
||||
nccl_port=port_args.nccl_port,
|
||||
)
|
||||
@@ -292,8 +303,18 @@ class Scheduler(
|
||||
_,
|
||||
_,
|
||||
) = self.tp_worker.get_worker_info()
|
||||
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
|
||||
if global_server_args_dict["max_micro_batch_size"] is None:
|
||||
global_server_args_dict["max_micro_batch_size"] = max(
|
||||
self.max_running_requests // server_args.pp_size, 1
|
||||
)
|
||||
|
||||
self.tp_group = self.tp_worker.get_tp_group()
|
||||
self.tp_cpu_group = self.tp_group.cpu_group
|
||||
self.attn_tp_group = self.tp_worker.get_attention_tp_group()
|
||||
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
|
||||
self.pp_group = get_pp_group()
|
||||
self.world_group = get_world_group()
|
||||
|
||||
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
|
||||
global_server_args_dict.update(worker_global_server_args_dict)
|
||||
set_random_seed(self.random_seed)
|
||||
@@ -673,26 +694,141 @@ class Scheduler(
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
@DynamicGradMode()
|
||||
def event_loop_pp(self):
|
||||
"""A non-overlap scheduler loop for pipeline parallelism."""
|
||||
mbs = [None] * self.pp_size
|
||||
last_mbs = [None] * self.pp_size
|
||||
self.running_mbs = [
|
||||
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||
]
|
||||
bids = [None] * self.pp_size
|
||||
pp_outputs: Optional[PPProxyTensors] = None
|
||||
while True:
|
||||
server_is_idle = True
|
||||
for mb_id in range(self.pp_size):
|
||||
self.running_batch = self.running_mbs[mb_id]
|
||||
self.last_batch = last_mbs[mb_id]
|
||||
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
mbs[mb_id] = self.get_next_batch_to_run()
|
||||
self.running_mbs[mb_id] = self.running_batch
|
||||
|
||||
self.cur_batch = mbs[mb_id]
|
||||
if self.cur_batch:
|
||||
server_is_idle = False
|
||||
result = self.run_batch(self.cur_batch)
|
||||
|
||||
# send the outputs to the next step
|
||||
if self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
next_token_ids, bids[mb_id] = (
|
||||
result.next_token_ids,
|
||||
result.bid,
|
||||
)
|
||||
pp_outputs = PPProxyTensors(
|
||||
{
|
||||
"next_token_ids": next_token_ids,
|
||||
}
|
||||
)
|
||||
# send the output from the last round to let the next stage worker run post processing
|
||||
self.pp_group.send_tensor_dict(
|
||||
pp_outputs.tensors,
|
||||
all_gather_group=self.attn_tp_group,
|
||||
)
|
||||
|
||||
# receive outputs and post-process (filter finished reqs) the coming microbatch
|
||||
next_mb_id = (mb_id + 1) % self.pp_size
|
||||
next_pp_outputs = None
|
||||
if mbs[next_mb_id] is not None:
|
||||
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
|
||||
self.pp_group.recv_tensor_dict(
|
||||
all_gather_group=self.attn_tp_group
|
||||
)
|
||||
)
|
||||
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
||||
output_result = GenerationBatchResult(
|
||||
logits_output=None,
|
||||
pp_hidden_states_proxy_tensors=None,
|
||||
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||
extend_input_len_per_req=None,
|
||||
extend_logprob_start_len_per_req=None,
|
||||
bid=bids[next_mb_id],
|
||||
)
|
||||
self.process_batch_result(mbs[next_mb_id], output_result)
|
||||
last_mbs[next_mb_id] = mbs[next_mb_id]
|
||||
|
||||
# carry the outputs to the next stage
|
||||
if not self.pp_group.is_last_rank:
|
||||
if self.cur_batch:
|
||||
bids[mb_id] = result.bid
|
||||
if pp_outputs:
|
||||
# send the outputs from the last round to let the next stage worker run post processing
|
||||
self.pp_group.send_tensor_dict(
|
||||
pp_outputs.tensors,
|
||||
all_gather_group=self.attn_tp_group,
|
||||
)
|
||||
|
||||
if not self.pp_group.is_last_rank:
|
||||
# send out reqs to the next stage
|
||||
dp_offset = self.dp_rank * self.attn_tp_size
|
||||
if self.attn_tp_rank == 0:
|
||||
point_to_point_pyobj(
|
||||
recv_reqs,
|
||||
self.pp_rank * self.tp_size + dp_offset,
|
||||
self.world_group.cpu_group,
|
||||
self.pp_rank * self.tp_size + dp_offset,
|
||||
(self.pp_rank + 1) * self.tp_size + dp_offset,
|
||||
)
|
||||
|
||||
# send out proxy tensors to the next stage
|
||||
if self.cur_batch:
|
||||
self.pp_group.send_tensor_dict(
|
||||
result.pp_hidden_states_proxy_tensors,
|
||||
all_gather_group=self.attn_tp_group,
|
||||
)
|
||||
|
||||
pp_outputs = next_pp_outputs
|
||||
|
||||
# When the server is idle, self-check and re-init some states
|
||||
if server_is_idle:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
def recv_requests(self) -> List[Req]:
|
||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||
if self.attn_tp_rank == 0:
|
||||
recv_reqs = []
|
||||
if self.pp_rank == 0:
|
||||
if self.attn_tp_rank == 0:
|
||||
recv_reqs = []
|
||||
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_req)
|
||||
while True:
|
||||
try:
|
||||
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_req)
|
||||
|
||||
while True:
|
||||
try:
|
||||
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_rpc)
|
||||
while True:
|
||||
try:
|
||||
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_rpc)
|
||||
else:
|
||||
recv_reqs = None
|
||||
else:
|
||||
recv_reqs = None
|
||||
if self.attn_tp_rank == 0:
|
||||
dp_offset = self.dp_rank * self.attn_tp_size
|
||||
recv_reqs = point_to_point_pyobj(
|
||||
[],
|
||||
self.pp_rank * self.tp_size + dp_offset,
|
||||
self.world_group.cpu_group,
|
||||
(self.pp_rank - 1) * self.tp_size + dp_offset,
|
||||
self.pp_rank * self.tp_size + dp_offset,
|
||||
)
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
if self.attn_tp_rank == 0:
|
||||
@@ -715,20 +851,27 @@ class Scheduler(
|
||||
control_reqs = None
|
||||
|
||||
if self.attn_tp_size != 1:
|
||||
attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
|
||||
work_reqs = broadcast_pyobj(
|
||||
work_reqs,
|
||||
self.attn_tp_rank,
|
||||
self.attn_tp_group.rank,
|
||||
self.attn_tp_cpu_group,
|
||||
src=attn_tp_rank_0,
|
||||
src=self.attn_tp_group.ranks[0],
|
||||
)
|
||||
if self.tp_size != 1:
|
||||
control_reqs = broadcast_pyobj(
|
||||
control_reqs, self.tp_rank, self.tp_cpu_group
|
||||
control_reqs,
|
||||
self.tp_group.rank,
|
||||
self.tp_cpu_group,
|
||||
src=self.tp_group.ranks[0],
|
||||
)
|
||||
recv_reqs = work_reqs + control_reqs
|
||||
elif self.tp_size != 1:
|
||||
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
|
||||
recv_reqs = broadcast_pyobj(
|
||||
recv_reqs,
|
||||
self.tp_group.rank,
|
||||
self.tp_cpu_group,
|
||||
src=self.tp_group.ranks[0],
|
||||
)
|
||||
return recv_reqs
|
||||
|
||||
def process_input_requests(self, recv_reqs: List):
|
||||
@@ -1026,12 +1169,14 @@ class Scheduler(
|
||||
|
||||
self.metrics_collector.log_stats(self.stats)
|
||||
|
||||
def log_decode_stats(self):
|
||||
def log_decode_stats(self, running_batch=None):
|
||||
batch = running_batch or self.running_batch
|
||||
|
||||
gap_latency = time.time() - self.last_decode_stats_tic
|
||||
self.last_decode_stats_tic = time.time()
|
||||
self.last_gen_throughput = self.num_generated_tokens / gap_latency
|
||||
self.num_generated_tokens = 0
|
||||
num_running_reqs = len(self.running_batch.reqs)
|
||||
num_running_reqs = len(batch.reqs)
|
||||
num_used = self.max_total_num_tokens - (
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
+ self.tree_cache.evictable_size()
|
||||
@@ -1131,19 +1276,25 @@ class Scheduler(
|
||||
|
||||
def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
|
||||
# Merge the prefill batch into the running batch
|
||||
chunked_req_to_exclude = set()
|
||||
if self.chunked_req:
|
||||
# Move the chunked request out of the batch so that we can merge
|
||||
# only finished requests to running_batch.
|
||||
chunked_req_to_exclude.add(self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
if self.last_batch and self.last_batch.forward_mode.is_extend():
|
||||
if self.chunked_req:
|
||||
# Move the chunked request out of the batch so that we can merge
|
||||
# only finished requests to running_batch.
|
||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
self.running_batch.batch_is_full = False
|
||||
if self.last_batch.chunked_req is not None:
|
||||
# In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
|
||||
# We need to discard it.
|
||||
chunked_req_to_exclude.add(self.last_batch.chunked_req)
|
||||
|
||||
# Filter batch
|
||||
last_bs = self.last_batch.batch_size()
|
||||
self.last_batch.filter_batch()
|
||||
self.last_batch.filter_batch(
|
||||
chunked_req_to_exclude=list(chunked_req_to_exclude)
|
||||
)
|
||||
if self.last_batch.batch_size() < last_bs:
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
@@ -1173,6 +1324,12 @@ class Scheduler(
|
||||
|
||||
return ret
|
||||
|
||||
def get_num_allocatable_reqs(self, running_bs):
|
||||
res = global_server_args_dict["max_micro_batch_size"] - running_bs
|
||||
if self.pp_size > 1:
|
||||
res = min(res, self.req_to_token_pool.available_size())
|
||||
return res
|
||||
|
||||
def get_new_batch_prefill(self) -> Optional[ScheduleBatch]:
|
||||
# Check if the grammar is ready in the grammar queue
|
||||
if self.grammar_queue:
|
||||
@@ -1185,7 +1342,12 @@ class Scheduler(
|
||||
return None
|
||||
|
||||
running_bs = len(self.running_batch.reqs)
|
||||
if running_bs >= self.max_running_requests:
|
||||
# Igore the check if self.chunked_req is not None.
|
||||
# In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0,
|
||||
# as the space for the chunked request has just been released.
|
||||
# In PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
|
||||
# Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak.
|
||||
if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req:
|
||||
self.running_batch.batch_is_full = True
|
||||
return None
|
||||
|
||||
@@ -1229,7 +1391,7 @@ class Scheduler(
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
if running_bs + len(adder.can_run_list) >= self.max_running_requests:
|
||||
if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs):
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
|
||||
@@ -1241,16 +1403,14 @@ class Scheduler(
|
||||
res = adder.add_one_req(
|
||||
req, self.chunked_req, self.enable_hierarchical_cache
|
||||
)
|
||||
|
||||
if res != AddReqResult.CONTINUE:
|
||||
if res == AddReqResult.NO_TOKEN:
|
||||
if self.enable_hierarchical_cache:
|
||||
# Set batch_is_full after making sure there are requests that can be served
|
||||
self.running_batch.batch_is_full = len(
|
||||
adder.can_run_list
|
||||
) > 0 or (
|
||||
self.running_batch is not None
|
||||
and not self.running_batch.is_empty()
|
||||
)
|
||||
) > 0 or (not self.running_batch.is_empty())
|
||||
else:
|
||||
self.running_batch.batch_is_full = True
|
||||
break
|
||||
@@ -1293,6 +1453,7 @@ class Scheduler(
|
||||
self.enable_overlap,
|
||||
self.spec_algorithm,
|
||||
self.server_args.enable_custom_logit_processor,
|
||||
chunked_req=self.chunked_req,
|
||||
)
|
||||
new_batch.prepare_for_extend()
|
||||
|
||||
@@ -1370,9 +1531,14 @@ class Scheduler(
|
||||
if self.is_generation:
|
||||
if self.spec_algorithm.is_none():
|
||||
model_worker_batch = batch.get_model_worker_batch()
|
||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||
model_worker_batch
|
||||
)
|
||||
if self.pp_group.is_last_rank:
|
||||
logits_output, next_token_ids = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
else:
|
||||
pp_hidden_states_proxy_tensors, _ = (
|
||||
self.tp_worker.forward_batch_generation(model_worker_batch)
|
||||
)
|
||||
bid = model_worker_batch.bid
|
||||
else:
|
||||
(
|
||||
@@ -1386,7 +1552,9 @@ class Scheduler(
|
||||
)
|
||||
self.spec_num_total_forward_ct += batch.batch_size()
|
||||
self.num_generated_tokens += num_accepted_tokens
|
||||
batch.output_ids = next_token_ids
|
||||
|
||||
if self.pp_group.is_last_rank:
|
||||
batch.output_ids = next_token_ids
|
||||
|
||||
# These 2 values are needed for processing the output, but the values can be
|
||||
# modified by overlap schedule. So we have to copy them here so that
|
||||
@@ -1401,8 +1569,13 @@ class Scheduler(
|
||||
extend_logprob_start_len_per_req = None
|
||||
|
||||
ret = GenerationBatchResult(
|
||||
logits_output=logits_output,
|
||||
next_token_ids=next_token_ids,
|
||||
logits_output=logits_output if self.pp_group.is_last_rank else None,
|
||||
pp_hidden_states_proxy_tensors=(
|
||||
pp_hidden_states_proxy_tensors
|
||||
if not self.pp_group.is_last_rank
|
||||
else None
|
||||
),
|
||||
next_token_ids=next_token_ids if self.pp_group.is_last_rank else None,
|
||||
extend_input_len_per_req=extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||
bid=bid,
|
||||
@@ -1553,6 +1726,7 @@ class Scheduler(
|
||||
|
||||
def move_ready_grammar_requests(self):
|
||||
"""Move requests whose grammar objects are ready from grammar_queue to waiting_queue."""
|
||||
|
||||
num_ready_reqs = 0
|
||||
for req in self.grammar_queue:
|
||||
try:
|
||||
@@ -1619,7 +1793,11 @@ class Scheduler(
|
||||
|
||||
def flush_cache(self):
|
||||
"""Flush the memory pool and cache."""
|
||||
if len(self.waiting_queue) == 0 and self.running_batch.is_empty():
|
||||
if (
|
||||
len(self.waiting_queue) == 0
|
||||
and self.running_batch.is_empty()
|
||||
and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs))
|
||||
):
|
||||
self.cur_batch = None
|
||||
self.last_batch = None
|
||||
self.tree_cache.reset()
|
||||
@@ -1657,7 +1835,6 @@ class Scheduler(
|
||||
ret["avg_spec_accept_length"] = (
|
||||
self.cum_spec_accept_length / self.cum_spec_accept_count
|
||||
)
|
||||
|
||||
if RECORD_STEP_TIME:
|
||||
ret["step_time_dict"] = self.step_time_dict
|
||||
return GetInternalStateReqOutput(
|
||||
@@ -1668,6 +1845,7 @@ class Scheduler(
|
||||
server_args_dict = recv_req.server_args
|
||||
args_allow_update = set(
|
||||
[
|
||||
"max_micro_batch_size",
|
||||
"speculative_accept_threshold_single",
|
||||
"speculative_accept_threshold_acc",
|
||||
]
|
||||
@@ -1678,6 +1856,14 @@ class Scheduler(
|
||||
logging.warning(f"Updating {k} is not supported.")
|
||||
if_success = False
|
||||
break
|
||||
elif k == "max_micro_batch_size" and (
|
||||
v > self.max_running_requests // self.pp_size or v < 1
|
||||
):
|
||||
logging.warning(
|
||||
f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]."
|
||||
)
|
||||
if_success = False
|
||||
break
|
||||
if if_success:
|
||||
if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
|
||||
avg_spec_accept_length = (
|
||||
@@ -1959,6 +2145,16 @@ class Scheduler(
|
||||
else:
|
||||
del self.sessions[session_id]
|
||||
|
||||
def get_print_prefix(self):
|
||||
prefix = ""
|
||||
if self.dp_rank is not None:
|
||||
prefix += f" DP{self.dp_rank}"
|
||||
if self.server_args.tp_size > 1:
|
||||
prefix += f" TP{self.tp_rank}"
|
||||
if self.pp_size > 1:
|
||||
prefix += f" PP{self.pp_rank}"
|
||||
return prefix
|
||||
|
||||
|
||||
def is_health_check_generate_req(recv_req):
|
||||
return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK")
|
||||
@@ -1983,14 +2179,18 @@ def run_scheduler_process(
|
||||
port_args: PortArgs,
|
||||
gpu_id: int,
|
||||
tp_rank: int,
|
||||
pp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
):
|
||||
# Generate the prefix
|
||||
if dp_rank is None:
|
||||
prefix = f" TP{tp_rank}"
|
||||
else:
|
||||
prefix = f" DP{dp_rank} TP{tp_rank}"
|
||||
prefix = ""
|
||||
if dp_rank is not None:
|
||||
prefix += f" DP{dp_rank}"
|
||||
if server_args.tp_size > 1:
|
||||
prefix += f" TP{tp_rank}"
|
||||
if server_args.pp_size > 1:
|
||||
prefix += f" PP{pp_rank}"
|
||||
|
||||
# Config the process
|
||||
kill_itself_when_parent_died()
|
||||
@@ -2012,7 +2212,7 @@ def run_scheduler_process(
|
||||
|
||||
# Create a scheduler and run the event loop
|
||||
try:
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
|
||||
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank)
|
||||
pipe_writer.send(
|
||||
{
|
||||
"status": "ready",
|
||||
@@ -2023,7 +2223,9 @@ def run_scheduler_process(
|
||||
disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode
|
||||
|
||||
if disaggregation_mode == DisaggregationMode.NULL:
|
||||
if scheduler.enable_overlap:
|
||||
if server_args.pp_size > 1:
|
||||
scheduler.event_loop_pp()
|
||||
elif scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap()
|
||||
else:
|
||||
scheduler.event_loop_normal()
|
||||
@@ -2032,6 +2234,7 @@ def run_scheduler_process(
|
||||
scheduler.event_loop_overlap_disagg_prefill()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
|
||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_decode()
|
||||
|
||||
Reference in New Issue
Block a user