# File: sglang/python/sglang/srt/disaggregation/prefill.py
# Snapshot: 2025-09-28 00:09:38 +08:00
# Size: 885 lines, 36 KiB
# Language: Python

"""
Life cycle of a request in the prefill server
1. Bootstrap Queue
a. Initialize a sender for each request
b. Use the queue to store requests whose bootstrap (handshake and preallocation) has not finished
c. Poll senders to check bootstrap state
d. Once bootstrap is complete, move request to Waiting Queue
2. Waiting Queue
a. Use PrefillAdder to pop requests
b. Run forward
c. Add the request to Inflight Queue
3. Inflight Queue
a. Poll (non-blocking) the sender of the request
b. Once the transfer has finished, return the request
"""
from __future__ import annotations
import logging
import threading
from collections import deque
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Type
import torch
from sglang.srt.disaggregation.base import BaseKVManager, KVPoll
from sglang.srt.disaggregation.utils import (
FAKE_BOOTSTRAP_HOST,
DisaggregationMode,
KVClassType,
MetadataBuffers,
ReqToMetadataIdxAllocator,
TransferBackend,
get_kv_class,
is_mla_backend,
kv_to_page_indices,
kv_to_page_num,
poll_and_all_reduce,
prepare_abort,
)
from sglang.srt.managers.schedule_batch import (
FINISH_LENGTH,
Req,
RequestStage,
ScheduleBatch,
)
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
from sglang.srt.utils import (
DynamicGradMode,
broadcast_pyobj,
point_to_point_pyobj,
require_mlp_sync,
)
if TYPE_CHECKING:
from torch.distributed import ProcessGroup
from sglang.srt.managers.scheduler import GenerationBatchResult, Scheduler
from sglang.srt.mem_cache.memory_pool import KVCache
logger = logging.getLogger(__name__)
class PrefillBootstrapQueue:
    """
    Store the requests in bootstrapping

    Requests enter through ``add`` / ``extend``, which attach a KV sender to
    each request, and leave through ``pop_bootstrapped`` once the sender no
    longer reports ``KVPoll.Bootstrapping``.
    """

    def __init__(
        self,
        token_to_kv_pool: KVCache,
        draft_token_to_kv_pool: Optional[KVCache],
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: MetadataBuffers,
        tp_rank: int,
        tp_size: int,
        gpu_id: int,
        bootstrap_port: int,
        gloo_group: ProcessGroup,
        max_total_num_tokens: int,
        decode_tp_size: int,
        decode_dp_size: int,
        scheduler: Scheduler,
        pp_rank: int,
        pp_size: int,
        transfer_backend: TransferBackend,
    ):
        """Store the configuration and build the KV manager for this queue."""
        self.token_to_kv_pool = token_to_kv_pool
        self.draft_token_to_kv_pool = draft_token_to_kv_pool
        self.is_mla_backend = is_mla_backend(token_to_kv_pool)

        self.metadata_buffers = metadata_buffers
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.decode_tp_size = decode_tp_size
        self.decode_dp_size = decode_dp_size
        self.pp_rank = pp_rank
        self.pp_size = pp_size
        self.gpu_id = gpu_id
        self.bootstrap_port = bootstrap_port
        self.queue: List[Req] = []
        self.gloo_group = gloo_group
        self.max_total_num_tokens = max_total_num_tokens
        self.scheduler = scheduler
        self.transfer_backend = transfer_backend
        # Must come last: _init_kv_manager reads the attributes set above.
        self.kv_manager = self._init_kv_manager()

    def _init_kv_manager(self) -> BaseKVManager:
        """Fill the backend-specific KV args and instantiate the KV manager in PREFILL mode."""
        kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
        kv_args = kv_args_class()
        kv_args.engine_rank = self.tp_rank
        kv_args.pp_rank = self.pp_rank
        kv_args.system_dp_rank = self.scheduler.dp_rank
        kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
        kv_args.prefill_pp_size = self.pp_size
        kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
        kv_data_ptrs, kv_data_lens, kv_item_lens = (
            self.token_to_kv_pool.get_contiguous_buf_infos()
        )

        if self.draft_token_to_kv_pool is not None:
            # We should also transfer draft model kv cache. The indices are
            # always shared with a target model.
            draft_kv_data_ptrs, draft_kv_data_lens, draft_kv_item_lens = (
                self.draft_token_to_kv_pool.get_contiguous_buf_infos()
            )
            kv_data_ptrs += draft_kv_data_ptrs
            kv_data_lens += draft_kv_data_lens
            kv_item_lens += draft_kv_item_lens

        kv_args.kv_data_ptrs = kv_data_ptrs
        kv_args.kv_data_lens = kv_data_lens
        kv_args.kv_item_lens = kv_item_lens
        if not self.is_mla_backend:
            kv_args.kv_head_num = self.token_to_kv_pool.head_num
        kv_args.page_size = self.token_to_kv_pool.page_size

        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
            self.metadata_buffers.get_buf_infos()
        )
        kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
        kv_args.gpu_id = self.scheduler.gpu_id

        kv_manager_class: Type[BaseKVManager] = get_kv_class(
            self.transfer_backend, KVClassType.MANAGER
        )
        kv_manager: BaseKVManager = kv_manager_class(
            kv_args,
            DisaggregationMode.PREFILL,
            self.scheduler.server_args,
            self.is_mla_backend,
        )
        return kv_manager

    def add(self, req: Req, num_kv_heads: int) -> None:
        """
        Attach a KV sender to ``req`` and enqueue it for bootstrapping.

        Requests whose prompt exceeds the KV capacity are aborted and not enqueued.
        num_kv_heads: accepted for caller compatibility; not used here.
        """
        if self._check_if_req_exceed_kv_capacity(req):
            return

        # Requests addressed to the fake bootstrap host always use the FAKE backend.
        backend = (
            TransferBackend.FAKE
            if req.bootstrap_host == FAKE_BOOTSTRAP_HOST
            else self.transfer_backend
        )
        kv_sender_class = get_kv_class(backend, KVClassType.SENDER)

        req.disagg_kv_sender = kv_sender_class(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
            dest_tp_ranks=[self.tp_rank],
            pp_rank=self.pp_rank,
        )
        self._process_req(req)
        req.add_latency(RequestStage.PREFILL_PREPARE)
        self.queue.append(req)

    def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
        """Call ``add`` for every request in ``reqs``, in order."""
        for req in reqs:
            self.add(req, num_kv_heads)

    def _check_if_req_exceed_kv_capacity(self, req: Req) -> bool:
        """
        Abort ``req`` with BAD_REQUEST when its prompt is longer than ``max_total_num_tokens``.

        Returns True when the request was aborted (and streamed back), False otherwise.
        """
        if len(req.origin_input_ids) > self.max_total_num_tokens:
            message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
            logger.error(message)
            prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
            self.scheduler.stream_output([req], req.return_logprob)
            return True
        return False

    def _process_req(self, req: Req) -> None:
        """
        Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
        """
        req.sampling_params.max_new_tokens = 1

    def pop_bootstrapped(
        self,
        return_failed_reqs: bool = False,
        rids_to_check: Optional[List[str]] = None,
    ) -> List[Req] | tuple[List[Req], List[Req]]:
        """
        pop the reqs which have finished bootstrapping

        return_failed_reqs: For PP, on rank 0, also return the failed reqs to notify the next rank
        rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.

        Returns the bootstrapped requests, or the pair
        ``(bootstrapped_reqs, failed_reqs)`` when ``return_failed_reqs`` is set.
        Failed requests are aborted, streamed back and dropped from the queue.
        """
        bootstrapped_reqs: List[Req] = []
        failed_reqs: List[Req] = []

        if self.queue:
            # Hash the rids once instead of scanning the list for every queued request.
            rids_filter = set(rids_to_check) if rids_to_check is not None else None
            indices_to_remove = set()

            polls = poll_and_all_reduce(
                [req.disagg_kv_sender for req in self.queue], self.gloo_group
            )

            for i, (req, poll) in enumerate(zip(self.queue, polls)):
                if rids_filter is not None:
                    # Only requests announced by the previous rank are considered.
                    if req.rid not in rids_filter:
                        continue
                    # Either waiting for input or failed
                    assert (
                        poll == KVPoll.WaitingForInput or poll == KVPoll.Failed
                    ), f"Unexpected bootstrap state {poll=} for {req.rid=}"

                if poll == KVPoll.Bootstrapping:
                    continue

                if poll == KVPoll.Failed:
                    error_message = f"Prefill bootstrap failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
                    try:
                        # failure_exception() is called only to capture the exception text.
                        req.disagg_kv_sender.failure_exception()
                    except Exception as e:
                        error_message += f" with exception {e}"
                    logger.error(error_message)
                    prepare_abort(
                        req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
                    )
                    self.scheduler.stream_output([req], req.return_logprob)
                    indices_to_remove.add(i)
                    failed_reqs.append(req)
                    if self.scheduler.enable_metrics:
                        self.scheduler.metrics_collector.increment_bootstrap_failed_reqs()
                    continue

                # KV.WaitingForInput - init here
                num_kv_indices = len(req.origin_input_ids)
                if self.req_to_metadata_buffer_idx_allocator.available_size() == 0:
                    # No metadata slot left: stop here, the remaining requests stay queued.
                    break

                req.metadata_buffer_index = (
                    self.req_to_metadata_buffer_idx_allocator.alloc()
                )
                assert req.metadata_buffer_index is not None

                num_pages = kv_to_page_num(
                    num_kv_indices, self.token_to_kv_pool.page_size
                )
                req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
                req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
                bootstrapped_reqs.append(req)
                indices_to_remove.add(i)

            self.queue = [
                entry
                for i, entry in enumerate(self.queue)
                if i not in indices_to_remove
            ]

        if return_failed_reqs:
            return bootstrapped_reqs, failed_reqs
        return bootstrapped_reqs
class SchedulerDisaggregationPrefillMixin:
"""
Mixin for Scheduler to handle disaggregation prefill

The methods annotate ``self`` as ``Scheduler`` and read scheduler state such as
``waiting_queue``, ``disagg_prefill_bootstrap_queue`` and
``disagg_prefill_inflight_queue``.
"""
@torch.no_grad()
def event_loop_normal_disagg_prefill(self: Scheduler) -> None:
    """A normal scheduler loop for prefill worker in disaggregation mode.

    Every iteration ingests new requests, moves bootstrapped requests to the
    waiting queue, runs at most one prefill batch and polls the inflight KV
    transfers. The loop never returns.
    """
    while True:
        incoming = self.recv_requests()
        self.process_input_requests(incoming)
        self.waiting_queue.extend(
            self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
        )
        self.process_prefill_chunk()

        new_batch = self.get_new_batch_prefill()
        if require_mlp_sync(self.server_args):
            new_batch = self.prepare_mlp_sync_batch(new_batch)
        self.cur_batch = new_batch

        if new_batch:
            batch_result = self.run_batch(new_batch)
            self.process_batch_result_disagg_prefill(new_batch, batch_result)

        if self.disagg_prefill_inflight_queue:
            self.process_disagg_prefill_inflight_queue()

        # Nothing ran and nothing is in flight: use the idle time for self-checks.
        if new_batch is None and not self.disagg_prefill_inflight_queue:
            self.self_check_during_idle()

        self.last_batch = new_batch
        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
        # Otherwise, it hangs under high concurrency
        self.running_batch.batch_is_full = False
@torch.no_grad()
def event_loop_overlap_disagg_prefill(self: Scheduler) -> None:
    """Overlap-schedule loop for the prefill worker in disaggregation mode.

    Launched batches are pushed onto ``self.result_queue`` together with their
    results; a queued entry is processed only on iterations where
    ``self.last_batch`` is set. The loop never returns.
    """
    self.result_queue = deque()

    while True:
        incoming = self.recv_requests()
        self.process_input_requests(incoming)
        self.waiting_queue.extend(
            self.disagg_prefill_bootstrap_queue.pop_bootstrapped()
        )
        self.process_prefill_chunk()

        new_batch = self.get_new_batch_prefill()
        if require_mlp_sync(self.server_args):
            new_batch = self.prepare_mlp_sync_batch(new_batch)
        self.cur_batch = new_batch

        if new_batch:
            batch_result = self.run_batch(new_batch)
            self.result_queue.append((new_batch.copy(), batch_result))

            if self.last_batch is None:
                # Create a dummy first batch to start the pipeline for overlap schedule.
                # It is now used for triggering the sampling_info_done event.
                dummy_batch = ScheduleBatch(
                    reqs=None,
                    forward_mode=ForwardMode.DUMMY_FIRST,
                    next_batch_sampling_info=self.tp_worker.cur_sampling_info,
                )
                self.set_next_batch_sampling_info_done(dummy_batch)

        if self.last_batch:
            prev_batch, prev_result = self.result_queue.popleft()
            if new_batch:
                prev_batch.next_batch_sampling_info = self.tp_worker.cur_sampling_info
            else:
                prev_batch.next_batch_sampling_info = None
            self.process_batch_result_disagg_prefill(prev_batch, prev_result)

        if self.disagg_prefill_inflight_queue:
            self.process_disagg_prefill_inflight_queue()

        # Nothing ran and nothing is in flight: use the idle time for self-checks.
        if new_batch is None and not self.disagg_prefill_inflight_queue:
            self.self_check_during_idle()

        self.last_batch = new_batch
        # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
        # Otherwise, it hangs under high concurrency
        self.running_batch.batch_is_full = False
def process_batch_result_disagg_prefill(
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
Adapted from process_batch_result_prefill

For every request in ``batch.reqs``:
- ``req.is_chunked <= 0`` (prefill finished): the next token is appended to
  ``req.output_ids``, the request is appended to ``disagg_prefill_inflight_queue``
  and its last KV chunk is sent via ``send_kv_chunk(req, last_chunk=True)``.
- otherwise (still being chunked): ``req.is_chunked`` is decremented and, when
  overlap is enabled, the chunk up to ``req.tmp_end_idx`` is sent.

launch_done: only read when ``self.enable_overlap`` is set, where it is passed to
``tp_worker.resolve_last_batch_result``.
"""
(
logits_output,
next_token_ids,
extend_input_len_per_req,
extend_logprob_start_len_per_req,
) = (
result.logits_output,
result.next_token_ids,
result.extend_input_len_per_req,
result.extend_logprob_start_len_per_req,
)
# Running offset into the flattened input logprobs, advanced per request below.
logprob_pt = 0
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
# The resolved values replace the logits_output / next_token_ids unpacked from `result` above.
logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
launch_done
)
else:
next_token_ids = result.next_token_ids.tolist()
if batch.return_logprob:
# Convert the logprob tensors to Python containers once for the whole batch.
if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
logits_output.next_token_logprobs.tolist()
)
if logits_output.input_token_logprobs is not None:
logits_output.input_token_logprobs = tuple(
logits_output.input_token_logprobs.tolist()
)
# Start of the current request's rows in logits_output.hidden_states.
hidden_state_offset = 0
# strict=True raises ValueError if batch.reqs and next_token_ids differ in length.
for i, (req, next_token_id) in enumerate(
zip(batch.reqs, next_token_ids, strict=True)
):
req: Req
if req.is_chunked <= 0:
# There is no output_ids for prefill
req.output_ids.append(next_token_id)
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
req.add_latency(RequestStage.PREFILL_FORWARD)
self.disagg_prefill_inflight_queue.append(req)
if (
logits_output is not None
and logits_output.hidden_states is not None
):
# Index of this request's last input position in the hidden states.
last_hidden_index = (
hidden_state_offset + extend_input_len_per_req[i] - 1
)
req.output_topk_p = batch.spec_info.topk_p[i]
req.output_topk_index = batch.spec_info.topk_index[i]
if self.spec_algorithm.is_eagle3():
req.hidden_states_tensor = (
batch.spec_info.hidden_states[i].cpu().clone()
)
else:
req.hidden_states_tensor = (
logits_output.hidden_states[last_hidden_index].cpu().clone()
)
hidden_state_offset += extend_input_len_per_req[i]
else:
req.hidden_states_tensor = None
if req.return_logprob:
assert extend_logprob_start_len_per_req is not None
assert extend_input_len_per_req is not None
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_logprob_return_values(
i,
req,
logprob_pt,
next_token_ids,
num_input_logprobs,
logits_output,
)
logprob_pt += num_input_logprobs
# Prefill is complete for this request: send the final KV chunk.
self.send_kv_chunk(req, last_chunk=True)
if req.grammar is not None:
# FIXME: this try-except block is for handling unexpected xgrammar issue.
try:
req.grammar.accept_token(next_token_id)
except ValueError as e:
# Grammar accept_token can raise ValueError if the token is not in the grammar.
# This can happen if the grammar is not set correctly or the token is invalid.
error_message = f"Grammar accept_token failed for req {req.rid} with token {next_token_id}: {e}"
self.tree_cache.cache_finished_req(req)
prepare_abort(
req,
error_message,
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
)
req.grammar.finished = req.finished()
else:
# being chunked reqs' prefill is not finished
req.is_chunked -= 1
if req.return_logprob:
extend_logprob_start_len = extend_logprob_start_len_per_req[i]
extend_input_len = extend_input_len_per_req[i]
if extend_logprob_start_len < extend_input_len:
# Update input logprobs.
num_input_logprobs = extend_input_len - extend_logprob_start_len
self.add_input_logprob_return_values(
i,
req,
logits_output,
logprob_pt,
num_input_logprobs,
last_prefill_chunk=False,
)
logprob_pt += num_input_logprobs
if self.enable_overlap:
# With overlap enabled, process_prefill_chunk only records tmp_end_idx; the chunk is sent here.
self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)
# We need to remove the sync in the following function for overlap schedule.
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
def process_disagg_prefill_inflight_queue(
    self: Scheduler, rids_to_check: Optional[List[str]] = None
) -> List[Req]:
    """
    Poll the requests in the middle of transfer. If done, return the request.

    rids_to_check: For PP, on rank > 0, check the rids from the previous rank has consensus with the current rank.
        When given, requests whose rid is not listed stay in the inflight queue
        and listed requests must already be in the Success or Failed state.

    Returns the requests removed from the inflight queue (transfer succeeded or
    failed). They are streamed to the client and their metadata buffer slot is freed.

    Raises AssertionError on a polling state this method does not know.
    """
    if not self.disagg_prefill_inflight_queue:
        return []

    # Hash the rids once instead of scanning the list for every inflight request.
    rids_filter = set(rids_to_check) if rids_to_check is not None else None

    polls = poll_and_all_reduce(
        [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
        self.attn_tp_cpu_group,
    )

    done_reqs: List[Req] = []
    undone_reqs: List[Req] = []
    # Check .poll() for the reqs in disagg_prefill_inflight_queue. If Success, respond to the client and remove it from the queue
    for req, poll in zip(self.disagg_prefill_inflight_queue, polls):
        if rids_filter is not None:
            if req.rid not in rids_filter:
                undone_reqs.append(req)
                continue
            assert (
                poll == KVPoll.Success or poll == KVPoll.Failed
            ), f"Unexpected transfer state {poll=} for {req.rid=}"

        if poll in [KVPoll.WaitingForInput, KVPoll.Transferring]:
            undone_reqs.append(req)
        elif poll == KVPoll.Success:  # transfer done
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            req.finished_reason = FINISH_LENGTH(length=0)
            # FIXME: clean up req's data in transfer engine
            if hasattr(req.disagg_kv_sender, "clear"):
                req.disagg_kv_sender.clear()
            done_reqs.append(req)
        elif poll == KVPoll.Failed:
            error_message = f"Prefill transfer failed for request rank={self.tp_rank} {req.rid=} {req.bootstrap_room=}"
            try:
                # failure_exception() is called only to capture the exception text.
                req.disagg_kv_sender.failure_exception()
            except Exception as e:
                error_message += f" with exception {e}"
            logger.warning(error_message)
            self.tree_cache.cache_finished_req(req)  # unlock the tree
            prepare_abort(
                req, error_message, status_code=HTTPStatus.INTERNAL_SERVER_ERROR
            )
            done_reqs.append(req)
            if self.enable_metrics:
                self.metrics_collector.increment_transfer_failed_reqs()
        else:
            # Raise explicitly so the check survives `python -O`.
            raise AssertionError(f"Unexpected polling state {poll=}")

    # Stream requests which have finished transfer
    self.stream_output(
        done_reqs,
        any(req.return_logprob for req in done_reqs),
        None,
    )
    for req in done_reqs:
        req: Req
        req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
        self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
        req.metadata_buffer_index = -1

    self.disagg_prefill_inflight_queue = undone_reqs
    return done_reqs
def get_transferred_rids(self: Scheduler) -> List[str]:
    """
    Used by PP, get the transferred rids but **do not pop**

    Returns the rid of every inflight request whose all-reduced poll state is
    Success or Failed, in queue order. The inflight queue is left untouched.
    """
    inflight = self.disagg_prefill_inflight_queue
    senders = [req.disagg_kv_sender for req in inflight]
    polls = poll_and_all_reduce(senders, self.tp_worker.get_tp_group().cpu_group)
    return [
        req.rid
        for req, poll in zip(inflight, polls)
        if poll == KVPoll.Success or poll == KVPoll.Failed
    ]
def process_prefill_chunk(self: Scheduler) -> None:
    """
    Detach the chunked request from the last extend batch and hand off its KV chunk.

    Does nothing unless the last batch ran in extend mode and a chunked request
    exists. With overlap enabled only ``tmp_end_idx`` is recorded on the request;
    otherwise the chunk is sent right away.
    """
    prev_batch = self.last_batch
    if not prev_batch or not prev_batch.forward_mode.is_extend():
        return

    chunked = self.chunked_req
    if not chunked:
        return

    # Move the chunked request out of the batch so that we can merge
    # only finished requests to running_batch.
    prev_batch.filter_batch(chunked_req_to_exclude=chunked)
    self.tree_cache.cache_unfinished_req(chunked, chunked=True)

    if not self.enable_overlap:
        self.send_kv_chunk(chunked)
    else:
        # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
        filled = len(chunked.fill_ids)
        prompt_len = len(chunked.origin_input_ids)
        chunked.tmp_end_idx = min(filled, prompt_len)

    # chunked request keeps its rid but will get a new req_pool_idx
    self.req_to_token_pool.free(chunked.req_pool_idx)
    self.running_batch.batch_is_full = False
def send_kv_chunk(
    self: Scheduler,
    req: Req,
    last_chunk: bool = False,
    end_idx: Optional[int] = None,
) -> None:
    """
    Send a prefilled chunk to the decode server

    The chunk covers token slots ``[req.start_send_idx, end_idx)``; ``end_idx``
    defaults to the shorter of ``req.fill_ids`` and ``req.origin_input_ids``.
    ``req.start_send_idx`` is advanced even when there is nothing to send.
    """
    page_size = self.token_to_kv_pool_allocator.page_size
    start_idx = req.start_send_idx

    if end_idx is None:
        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
    if not last_chunk:
        # if not the last chunk and the last page is partial, delay the last partial page to the next send
        end_idx -= end_idx % page_size

    token_slots = self.req_to_token_pool.req_to_token[
        req.req_pool_idx, start_idx:end_idx
    ]
    kv_indices = token_slots.cpu().numpy()
    req.start_send_idx = end_idx

    if last_chunk:
        self.disagg_metadata_buffers.set_buf(req)

    page_indices = kv_to_page_indices(kv_indices, page_size)
    if len(page_indices) > 0:
        req.disagg_kv_sender.send(page_indices)
    else:
        logger.info(
            f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
        )
# PP
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
"""
An event loop for the prefill server in pipeline parallelism.

Rules:
1. Each stage runs in the same order and is notified by the previous stage.
2. Each send/recv operation is blocking and matched by the neighboring stage.

Regular Schedule:
====================================================================
Stage i                         | Stage i+1
send ith req                    | recv ith req
send ith proxy                  | recv ith proxy
send prev (i+1)th carry         | recv prev (i+1)th carry
====================================================================

Prefill Server Schedule:
====================================================================
Stage i                         | Stage i+1
send ith req                    | recv ith req
send ith bootstrap req          | recv ith bootstrap req
send ith transferred req        | recv ith transferred req
send ith proxy                  | recv ith proxy
send prev (i+1)th carry         | recv prev (i+1)th carry
send prev (i+1)th release req   | recv prev (i+1)th release req
====================================================================

There are two additional elements compared to the regular schedule:

1. Bootstrap Requests:
   a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
   b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
   c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

2. Transferred Requests + Release Requests:
   a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
   b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
   c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
"""
# Imported inside the function; at module level this name is only imported under TYPE_CHECKING.
from sglang.srt.managers.scheduler import GenerationBatchResult
# Per-microbatch state, one slot per pipeline stage (pp_size slots each):
# mbs = batch scheduled in the current round, last_mbs = batch that was post-processed last.
mbs = [None] * self.pp_size
last_mbs = [None] * self.pp_size
self.running_mbs = [
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
]
# bids[mb_id] holds result.bid of the batch run for that microbatch.
bids = [None] * self.pp_size
pp_outputs: Optional[PPProxyTensors] = None
# Either success or failed
bootstrapped_rids: List[str] = []
transferred_rids: List[str] = []
release_rids: Optional[List[str]] = None
# transferred microbatch
tmbs = [None] * self.pp_size
ENABLE_RELEASE = True # For debug
while True:
server_is_idle = True
for mb_id in range(self.pp_size):
# Switch the scheduler state to the microbatch handled in this iteration.
self.running_batch = self.running_mbs[mb_id]
self.last_batch = last_mbs[mb_id]
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
if self.pp_group.is_first_rank:
# First rank, pop the bootstrap reqs from the bootstrap queue
bootstrapped_reqs, failed_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
return_failed_reqs=True
)
)
# Failed rids are forwarded as well so later ranks pop them from their queues too.
bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
req.rid for req in failed_reqs
]
self.waiting_queue.extend(bootstrapped_reqs)
else:
# Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
bootstrapped_rids = self.recv_pyobj_from_prev_stage()
bootstrapped_reqs = (
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
rids_to_check=bootstrapped_rids
)
)
self.waiting_queue.extend(bootstrapped_reqs)
if self.pp_group.is_first_rank:
transferred_rids = self.get_transferred_rids()
# if other ranks,
else:
# 1. recv previous stage's transferred reqs info
prev_transferred_rids = self.recv_pyobj_from_prev_stage()
# 2. get the current stage's transferred reqs info
curr_transferred_rids = self.get_transferred_rids()
# 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
transferred_rids = list(
set(prev_transferred_rids) & set(curr_transferred_rids)
)
# Remember this microbatch's consensus list; a non-None slot gates the release recv below.
tmbs[mb_id] = transferred_rids
self.process_prefill_chunk()
mbs[mb_id] = self.get_new_batch_prefill()
self.running_mbs[mb_id] = self.running_batch
self.cur_batch = mbs[mb_id]
if self.cur_batch:
server_is_idle = False
result = self.run_batch(self.cur_batch)
# send the outputs to the next step
if self.pp_group.is_last_rank:
if self.cur_batch:
next_token_ids, bids[mb_id] = (
result.next_token_ids,
result.bid,
)
pp_outputs = PPProxyTensors(
{
"next_token_ids": next_token_ids,
}
)
# send the output from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if self.pp_group.is_last_rank:
# At the last stage, all stages have reached the consensus to release memory for transferred_rids
release_rids = transferred_rids
# send to the first rank
self.send_pyobj_to_next_stage(release_rids)
# receive outputs and post-process (filter finished reqs) the coming microbatch
next_mb_id = (mb_id + 1) % self.pp_size
next_pp_outputs = None
next_release_rids = None
if mbs[next_mb_id] is not None:
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
self.pp_group.recv_tensor_dict(
all_gather_group=self.attn_tp_group
)
)
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
# NOTE(review): `result` below is whatever the most recent run_batch in this loop returned,
# not necessarily the result of mbs[next_mb_id] — confirm that this is intended.
output_result = GenerationBatchResult(
logits_output=None,
pp_hidden_states_proxy_tensors=None,
next_token_ids=next_pp_outputs["next_token_ids"],
extend_input_len_per_req=None,
extend_logprob_start_len_per_req=None,
bid=bids[next_mb_id],
can_run_cuda_graph=result.can_run_cuda_graph,
)
self.process_batch_result_disagg_prefill(
mbs[next_mb_id], output_result
)
last_mbs[next_mb_id] = mbs[next_mb_id]
if ENABLE_RELEASE:
if tmbs[next_mb_id] is not None:
# recv consensus rids from the previous rank
next_release_rids = self.recv_pyobj_from_prev_stage()
self.process_disagg_prefill_inflight_queue(next_release_rids)
# carry the outputs to the next stage
if not self.pp_group.is_last_rank:
if self.cur_batch:
bids[mb_id] = result.bid
if pp_outputs:
# send the outputs from the last round to let the next stage worker run post processing
self.pp_group.send_tensor_dict(
pp_outputs.tensors,
all_gather_group=self.attn_tp_group,
)
if ENABLE_RELEASE:
if release_rids is not None:
self.send_pyobj_to_next_stage(release_rids)
if not self.pp_group.is_last_rank:
# send out reqs to the next stage
self.send_pyobj_to_next_stage(recv_reqs)
self.send_pyobj_to_next_stage(bootstrapped_rids)
self.send_pyobj_to_next_stage(transferred_rids)
# send out proxy tensors to the next stage
if self.cur_batch:
self.pp_group.send_tensor_dict(
result.pp_hidden_states_proxy_tensors,
all_gather_group=self.attn_tp_group,
)
# What was received for the next microbatch becomes the carry for the next iteration.
pp_outputs = next_pp_outputs
release_rids = next_release_rids
self.running_batch.batch_is_full = False
# Debug path: with release propagation disabled, poll the inflight queue directly.
if not ENABLE_RELEASE:
if len(self.disagg_prefill_inflight_queue) > 0:
self.process_disagg_prefill_inflight_queue()
# When the server is idle, self-check and re-init some states
if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
self.check_memory()
self.check_tree_cache()
self.new_token_ratio = self.init_new_token_ratio
def send_pyobj_to_next_stage(self, data):
    """
    Send ``data`` to the next pipeline stage (wrapping around after the last one).

    Only the rank with ``attn_tp_rank == 0`` sends; every other rank returns
    without doing anything.
    """
    if self.attn_tp_rank != 0:
        return

    dp_offset = self.attn_dp_rank * self.attn_tp_size
    this_rank = self.pp_rank * self.tp_size + dp_offset
    next_stage = (self.pp_rank + 1) % self.pp_size
    next_rank = next_stage * self.tp_size + dp_offset
    point_to_point_pyobj(
        data,
        this_rank,
        self.world_group.device_group,
        this_rank,
        next_rank,
    )
def recv_pyobj_from_prev_stage(self):
    """
    Receive a Python object from the previous pipeline stage (wrapping around before the first one).

    Only the rank with ``attn_tp_rank == 0`` receives; when ``tp_size != 1`` the
    value is then broadcast over ``tp_cpu_group`` so every TP rank returns it.
    """
    data = None
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        this_rank = self.pp_rank * self.tp_size + dp_offset
        prev_stage = (self.pp_rank - 1) % self.pp_size
        prev_rank = prev_stage * self.tp_size + dp_offset
        data = point_to_point_pyobj(
            [],
            this_rank,
            self.world_group.device_group,
            prev_rank,
            this_rank,
        )

    if self.tp_size == 1:
        return data
    return broadcast_pyobj(
        data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
    )