[code move] move pp into a separate mixin (#11838)
This commit is contained in:
@@ -53,13 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
||||
NSATokenToKVPool,
|
||||
SWAKVPool,
|
||||
)
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
from sglang.srt.utils import (
|
||||
DynamicGradMode,
|
||||
broadcast_pyobj,
|
||||
point_to_point_pyobj,
|
||||
require_mlp_sync,
|
||||
)
|
||||
from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import ProcessGroup
|
||||
@@ -685,218 +679,6 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
return
|
||||
req.disagg_kv_sender.send(page_indices, state_indices)
|
||||
|
||||
# PP
|
||||
# PP event loop for the disaggregated prefill server.
@DynamicGradMode()
def event_loop_pp_disagg_prefill(self: Scheduler):
    """
    An event loop for the prefill server in pipeline parallelism.

    Rules:
    1. Each stage runs in the same order and is notified by the previous stage.
    2. Each send/recv operation is blocking and matched by the neighboring stage.

    Regular Schedule:
    ====================================================================
    Stage i                   | Stage i+1
    send ith req              | recv ith req
    send ith proxy            | recv ith proxy
    send prev (i+1)th carry   | recv prev (i+1)th carry
    ====================================================================

    Prefill Server Schedule:
    ====================================================================
    Stage i                        | Stage i+1
    send ith req                   | recv ith req
    send ith bootstrap req         | recv ith bootstrap req
    send ith transferred req       | recv ith transferred req
    send ith proxy                 | recv ith proxy
    send prev (i+1)th carry        | recv prev (i+1)th carry
    send prev (i+1)th release req  | recv prev (i+1)th release req
    ====================================================================

    There are two additional elements compared to the regular schedule:

    1. Bootstrap Requests:
        a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
        b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
        c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

    2. Transferred Requests + Release Requests:
        a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
        b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
        c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.

    NOTE: the relative order of sends/recvs below is load-bearing; each
    blocking operation must be matched in the same order by the neighbor
    stage, or the pipeline deadlocks.
    """
    # Local import to avoid a circular import with the scheduler module.
    from sglang.srt.managers.scheduler import GenerationBatchResult

    # One in-flight microbatch slot per pipeline stage.
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    pp_outputs: Optional[PPProxyTensors] = None

    # Either success or failed
    bootstrapped_rids: List[str] = []
    transferred_rids: List[str] = []
    release_rids: Optional[List[str]] = None

    # transferred microbatch
    tmbs = [None] * self.pp_size

    ENABLE_RELEASE = True  # For debug

    while True:
        server_is_idle = True

        for mb_id in range(self.pp_size):
            # Restore this microbatch's scheduling state.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()

            self.process_input_requests(recv_reqs)

            if self.pp_group.is_first_rank:
                # First rank, pop the bootstrap reqs from the bootstrap queue
                bootstrapped_reqs, failed_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        return_failed_reqs=True
                    )
                )
                bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                    req.rid for req in failed_reqs
                ]
                self.waiting_queue.extend(bootstrapped_reqs)
            else:
                # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
                bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                bootstrapped_reqs = (
                    self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                        rids_to_check=bootstrapped_rids
                    )
                )
                self.waiting_queue.extend(bootstrapped_reqs)

            if self.pp_group.is_first_rank:
                transferred_rids = self.get_transferred_rids()
            # if other ranks,
            else:
                # 1. recv previous stage's transferred reqs info
                prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                # 2. get the current stage's transferred reqs info
                curr_transferred_rids = self.get_transferred_rids()
                # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                transferred_rids = list(
                    set(prev_transferred_rids) & set(curr_transferred_rids)
                )

            tmbs[mb_id] = transferred_rids

            self.process_prefill_chunk()
            mbs[mb_id] = self.get_new_batch_prefill()
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids = result.next_token_ids
                    pp_outputs = PPProxyTensors(
                        {
                            "next_token_ids": next_token_ids,
                        }
                    )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            if ENABLE_RELEASE:
                if self.pp_group.is_last_rank:
                    # At the last stage, all stages has reached the consensus to release memory for transferred_rids
                    release_rids = transferred_rids
                    # send to the first rank
                    self.send_pyobj_to_next_stage(release_rids)

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            next_release_rids = None

            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # NOTE(review): `result` here is the current microbatch's run;
                # this assumes `result` is bound whenever mbs[next_mb_id] is
                # not None — confirm against scheduler invariants.
                output_result = GenerationBatchResult(
                    logits_output=None,
                    pp_hidden_states_proxy_tensors=None,
                    next_token_ids=next_pp_outputs["next_token_ids"],
                    extend_input_len_per_req=None,
                    extend_logprob_start_len_per_req=None,
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result_disagg_prefill(
                    mbs[next_mb_id], output_result
                )

                last_mbs[next_mb_id] = mbs[next_mb_id]

            if ENABLE_RELEASE:
                if tmbs[next_mb_id] is not None:
                    # recv consensus rids from the previous rank
                    next_release_rids = self.recv_pyobj_from_prev_stage()
                    self.process_disagg_prefill_inflight_queue(next_release_rids)

            # carry the outputs to the next stage
            if not self.pp_group.is_last_rank:
                if pp_outputs:
                    # send the outputs from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )
                if ENABLE_RELEASE:
                    if release_rids is not None:
                        self.send_pyobj_to_next_stage(release_rids)

            if not self.pp_group.is_last_rank:
                # send out reqs to the next stage
                self.send_pyobj_to_next_stage(recv_reqs)
                self.send_pyobj_to_next_stage(bootstrapped_rids)
                self.send_pyobj_to_next_stage(transferred_rids)

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    # FIXME(lsyin): remove this assert
                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # Carry this round's received outputs/release list into the next
            # microbatch slot.
            pp_outputs = next_pp_outputs
            release_rids = next_release_rids

            self.running_batch.batch_is_full = False

        if not ENABLE_RELEASE:
            if len(self.disagg_prefill_inflight_queue) > 0:
                self.process_disagg_prefill_inflight_queue()

        # When the server is idle, self-check and re-init some states
        if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
            self.check_memory()
            self.check_tree_cache()
            self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
def send_pyobj_to_next_stage(self, data):
|
||||
if self.attn_tp_rank == 0:
|
||||
dp_offset = self.attn_dp_rank * self.attn_tp_size
|
||||
|
||||
@@ -78,6 +78,7 @@ from sglang.srt.utils import flatten_nested_list
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
@@ -1527,8 +1528,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
if self.is_v2_eagle:
|
||||
# TODO(spec-v2): all v2 spec should go through this path
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
|
||||
draft_input: EagleDraftInput = self.spec_info
|
||||
draft_input.prepare_for_decode(self)
|
||||
|
||||
@@ -1585,8 +1584,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
def maybe_wait_verify_done(self):
|
||||
if self.is_v2_eagle:
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
|
||||
draft_input: EagleDraftInput = self.spec_info
|
||||
if draft_input.verify_done is not None:
|
||||
draft_input.verify_done.synchronize()
|
||||
|
||||
@@ -63,7 +63,6 @@ from sglang.srt.distributed import get_pp_group, get_world_group
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe import initialize_moe_config
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
@@ -114,7 +113,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||
from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
|
||||
from sglang.srt.managers.overlap_utils import FutureMap
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
FINISH_ABORT,
|
||||
ModelWorkerBatch,
|
||||
@@ -136,22 +135,21 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
|
||||
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||
SchedulerOutputProcessorMixin,
|
||||
)
|
||||
from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
|
||||
from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
|
||||
from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
|
||||
from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||
SchedulerUpdateWeightsMixin,
|
||||
)
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
from sglang.srt.parser.reasoning_parser import ReasoningParser
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.tracing.trace import (
|
||||
process_tracing_init,
|
||||
@@ -198,77 +196,6 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
|
||||
GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
|
||||
|
||||
|
||||
@dataclass
class GenerationBatchResult:
    """Result of running one generation (forward) batch.

    Carries the model outputs plus bookkeeping needed by the output
    processor, the overlap scheduler, and pipeline-parallel relaying.
    """

    logits_output: Optional[LogitsProcessorOutput] = None
    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

    # For output processing
    extend_input_len_per_req: Optional[List[int]] = None
    extend_logprob_start_len_per_req: Optional[List[int]] = None

    # For overlap scheduling
    copy_done: Optional[torch.cuda.Event] = None
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Copy tensors to CPU in overlap scheduling.
        Only the tensors which are needed for processing results are copied,
        e.g., next_token_ids, logits outputs

        NOTE(review): assumes `logits_output`, `next_token_ids`, and
        `copy_done` are non-None when this is called — confirm callers
        uphold this; all copies are async (non_blocking) and `copy_done`
        is recorded so consumers can synchronize on completion.
        """
        if return_logprob:
            if self.logits_output.next_token_logits is not None:
                self.logits_output.next_token_logits = (
                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
                )
            if self.logits_output.input_token_logprobs is not None:
                self.logits_output.input_token_logprobs = (
                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
                )
        if self.logits_output.hidden_states is not None:
            self.logits_output.hidden_states = self.logits_output.hidden_states.to(
                "cpu", non_blocking=True
            )
        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)

        if self.accept_lens is not None:
            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

        if self.allocate_lens is not None:
            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)

        self.copy_done.record()

    @classmethod
    def from_pp_proxy(
        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
    ):
        """Build a result from tensors relayed by the previous PP stage."""
        # TODO(lsyin): refactor PP and avoid using dict
        proxy_dict = next_pp_outputs.tensors
        return cls(
            logits_output=logits_output,
            pp_hidden_states_proxy_tensors=None,
            next_token_ids=next_pp_outputs["next_token_ids"],
            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
            extend_logprob_start_len_per_req=proxy_dict.get(
                "extend_logprob_start_len_per_req", None
            ),
            can_run_cuda_graph=can_run_cuda_graph,
        )
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingBatchResult:
|
||||
embeddings: torch.Tensor
|
||||
@@ -281,6 +208,7 @@ class Scheduler(
|
||||
SchedulerMetricsMixin,
|
||||
SchedulerDisaggregationDecodeMixin,
|
||||
SchedulerDisaggregationPrefillMixin,
|
||||
SchedulerPPMixin,
|
||||
):
|
||||
"""A scheduler that manages a tensor parallel GPU worker."""
|
||||
|
||||
@@ -1058,128 +986,6 @@ class Scheduler(
|
||||
self.launch_batch_sample_if_needed(batch_result)
|
||||
self.last_batch = batch
|
||||
|
||||
@DynamicGradMode()
def event_loop_pp(self):
    """A non-overlap scheduler loop for pipeline parallelism.

    Keeps `pp_size` microbatches in flight; each iteration runs microbatch
    `mb_id` and post-processes the outputs of microbatch `mb_id + 1`.
    The relative order of the blocking send/recv collectives below must be
    matched by the neighboring stages, or the pipeline deadlocks.
    """
    mbs = [None] * self.pp_size
    last_mbs = [None] * self.pp_size
    self.running_mbs = [
        ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
    ]
    pp_outputs: Optional[PPProxyTensors] = None
    while True:
        server_is_idle = True
        for mb_id in range(self.pp_size):
            # Restore this microbatch's scheduling state.
            self.running_batch = self.running_mbs[mb_id]
            self.last_batch = last_mbs[mb_id]

            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)
            mbs[mb_id] = self.get_next_batch_to_run()
            self.running_mbs[mb_id] = self.running_batch

            self.cur_batch = mbs[mb_id]
            if self.cur_batch:
                server_is_idle = False
                result = self.run_batch(self.cur_batch)

            # (last rank) send the outputs to the next step
            if self.pp_group.is_last_rank:
                if self.cur_batch:
                    next_token_ids = result.next_token_ids
                    if self.cur_batch.return_logprob:
                        # Flatten logits_output fields into the proxy dict
                        # under "logits_output."-prefixed keys.
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                                "extend_input_len_per_req": result.extend_input_len_per_req,
                                "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                            }
                            | (
                                {
                                    f"logits_output.{k}": v
                                    for k, v in result.logits_output.__dict__.items()
                                }
                                if result.logits_output is not None
                                else {}
                            )
                        )
                    else:
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                    # send the output from the last round to let the next stage worker run post processing
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            # receive outputs and post-process (filter finished reqs) the coming microbatch
            next_mb_id = (mb_id + 1) % self.pp_size
            next_pp_outputs = None
            if mbs[next_mb_id] is not None:
                next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                    self.pp_group.recv_tensor_dict(
                        all_gather_group=self.attn_tp_group
                    )
                )
                mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                # Reassemble the logits_output from its prefixed proxy keys.
                logits_output_args = {
                    k[len("logits_output.") :]: v
                    for k, v in next_pp_outputs.tensors.items()
                    if k.startswith("logits_output.")
                }
                if len(logits_output_args) > 0:
                    logits_output = LogitsProcessorOutput(**logits_output_args)
                else:
                    logits_output = None

                # NOTE(review): assumes `result` is bound whenever
                # mbs[next_mb_id] is not None — confirm scheduler invariants.
                output_result = GenerationBatchResult.from_pp_proxy(
                    logits_output=logits_output,
                    next_pp_outputs=next_pp_outputs,
                    can_run_cuda_graph=result.can_run_cuda_graph,
                )
                self.process_batch_result(mbs[next_mb_id], output_result)
                last_mbs[next_mb_id] = mbs[next_mb_id]

            # (not last rank)
            if not self.pp_group.is_last_rank:
                # carry the outputs to the next stage
                # send the outputs from the last round to let the next stage worker run post processing
                if pp_outputs:
                    self.pp_group.send_tensor_dict(
                        pp_outputs.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

                # send out reqs to the next stage
                dp_offset = self.attn_dp_rank * self.attn_tp_size
                if self.attn_tp_rank == 0:
                    point_to_point_pyobj(
                        recv_reqs,
                        self.pp_rank * self.tp_size + dp_offset,
                        self.world_group.device_group,
                        self.pp_rank * self.tp_size + dp_offset,
                        (self.pp_rank + 1) * self.tp_size + dp_offset,
                    )

                # send out proxy tensors to the next stage
                if self.cur_batch:
                    # FIXME(lsyin): remove this assert
                    assert result.pp_hidden_states_proxy_tensors.tensors is not None
                    self.pp_group.send_tensor_dict(
                        result.pp_hidden_states_proxy_tensors.tensors,
                        all_gather_group=self.attn_tp_group,
                    )

            pp_outputs = next_pp_outputs

        # When the server is idle, self-check and re-init some states
        if server_is_idle:
            # When the server is idle, do self-check and re-init some states
            self.self_check_during_idle()
|
||||
|
||||
def recv_requests(self) -> List[Req]:
|
||||
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
|
||||
|
||||
|
||||
341
python/sglang/srt/managers/scheduler_pp_mixin.py
Normal file
341
python/sglang/srt/managers/scheduler_pp_mixin.py
Normal file
@@ -0,0 +1,341 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.utils import GenerationBatchResult
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
|
||||
|
||||
|
||||
class SchedulerPPMixin:
    """Mixin holding the pipeline-parallel (PP) event loops for `Scheduler`.

    Mixed into the Scheduler class; relies on scheduler attributes such as
    `pp_size`, `pp_group`, `attn_tp_group`, `running_batch`, and the
    request queues. All send/recv operations are blocking and must be
    matched in the same order by the neighboring pipeline stage.
    """

    @DynamicGradMode()
    def event_loop_pp(self):
        """A non-overlap scheduler loop for pipeline parallelism.

        Keeps `pp_size` microbatches in flight; each iteration runs
        microbatch `mb_id` and post-processes the outputs of microbatch
        `mb_id + 1`.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        pp_outputs: Optional[PPProxyTensors] = None
        while True:
            server_is_idle = True
            for mb_id in range(self.pp_size):
                # Restore this microbatch's scheduling state.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()
                self.process_input_requests(recv_reqs)
                mbs[mb_id] = self.get_next_batch_to_run()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # (last rank) send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids = result.next_token_ids
                        if self.cur_batch.return_logprob:
                            # Flatten logits_output fields into the proxy
                            # dict under "logits_output."-prefixed keys.
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                    "extend_input_len_per_req": result.extend_input_len_per_req,
                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
                                }
                                | (
                                    {
                                        f"logits_output.{k}": v
                                        for k, v in result.logits_output.__dict__.items()
                                    }
                                    if result.logits_output is not None
                                    else {}
                                )
                            )
                        else:
                            pp_outputs = PPProxyTensors(
                                {
                                    "next_token_ids": next_token_ids,
                                }
                            )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # Reassemble the logits_output from its prefixed keys.
                    logits_output_args = {
                        k[len("logits_output.") :]: v
                        for k, v in next_pp_outputs.tensors.items()
                        if k.startswith("logits_output.")
                    }
                    if len(logits_output_args) > 0:
                        logits_output = LogitsProcessorOutput(**logits_output_args)
                    else:
                        logits_output = None

                    # NOTE(review): assumes `result` is bound whenever
                    # mbs[next_mb_id] is not None — confirm invariants.
                    output_result = GenerationBatchResult.from_pp_proxy(
                        logits_output=logits_output,
                        next_pp_outputs=next_pp_outputs,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result(mbs[next_mb_id], output_result)
                    last_mbs[next_mb_id] = mbs[next_mb_id]

                # (not last rank)
                if not self.pp_group.is_last_rank:
                    # carry the outputs to the next stage
                    # send the outputs from the last round to let the next stage worker run post processing
                    if pp_outputs:
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                    # send out reqs to the next stage
                    dp_offset = self.attn_dp_rank * self.attn_tp_size
                    if self.attn_tp_rank == 0:
                        point_to_point_pyobj(
                            recv_reqs,
                            self.pp_rank * self.tp_size + dp_offset,
                            self.world_group.device_group,
                            self.pp_rank * self.tp_size + dp_offset,
                            (self.pp_rank + 1) * self.tp_size + dp_offset,
                        )

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        # FIXME(lsyin): remove this assert
                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                pp_outputs = next_pp_outputs

            # When the server is idle, self-check and re-init some states
            if server_is_idle:
                # When the server is idle, do self-check and re-init some states
                self.self_check_during_idle()

    @DynamicGradMode()
    def event_loop_pp_disagg_prefill(self):
        """
        An event loop for the prefill server in pipeline parallelism.

        Rules:
        1. Each stage runs in the same order and is notified by the previous stage.
        2. Each send/recv operation is blocking and matched by the neighboring stage.

        Regular Schedule:
        ====================================================================
        Stage i                   | Stage i+1
        send ith req              | recv ith req
        send ith proxy            | recv ith proxy
        send prev (i+1)th carry   | recv prev (i+1)th carry
        ====================================================================

        Prefill Server Schedule:
        ====================================================================
        Stage i                        | Stage i+1
        send ith req                   | recv ith req
        send ith bootstrap req         | recv ith bootstrap req
        send ith transferred req       | recv ith transferred req
        send ith proxy                 | recv ith proxy
        send prev (i+1)th carry        | recv prev (i+1)th carry
        send prev (i+1)th release req  | recv prev (i+1)th release req
        ====================================================================

        There are two additional elements compared to the regular schedule:

        1. Bootstrap Requests:
            a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
            b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
            c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.

        2. Transferred Requests + Release Requests:
            a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
            b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
            c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
        """
        mbs = [None] * self.pp_size
        last_mbs = [None] * self.pp_size
        self.running_mbs = [
            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
        ]
        pp_outputs: Optional[PPProxyTensors] = None

        # Either success or failed
        bootstrapped_rids: List[str] = []
        transferred_rids: List[str] = []
        release_rids: Optional[List[str]] = None

        # transferred microbatch
        tmbs = [None] * self.pp_size

        ENABLE_RELEASE = True  # For debug

        while True:
            server_is_idle = True

            for mb_id in range(self.pp_size):
                # Restore this microbatch's scheduling state.
                self.running_batch = self.running_mbs[mb_id]
                self.last_batch = last_mbs[mb_id]

                recv_reqs = self.recv_requests()

                self.process_input_requests(recv_reqs)

                if self.pp_group.is_first_rank:
                    # First rank, pop the bootstrap reqs from the bootstrap queue
                    bootstrapped_reqs, failed_reqs = (
                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                            return_failed_reqs=True
                        )
                    )
                    bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
                        req.rid for req in failed_reqs
                    ]
                    self.waiting_queue.extend(bootstrapped_reqs)
                else:
                    # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
                    bootstrapped_rids = self.recv_pyobj_from_prev_stage()
                    bootstrapped_reqs = (
                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
                            rids_to_check=bootstrapped_rids
                        )
                    )
                    self.waiting_queue.extend(bootstrapped_reqs)

                if self.pp_group.is_first_rank:
                    transferred_rids = self.get_transferred_rids()
                # if other ranks,
                else:
                    # 1. recv previous stage's transferred reqs info
                    prev_transferred_rids = self.recv_pyobj_from_prev_stage()
                    # 2. get the current stage's transferred reqs info
                    curr_transferred_rids = self.get_transferred_rids()
                    # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
                    transferred_rids = list(
                        set(prev_transferred_rids) & set(curr_transferred_rids)
                    )

                tmbs[mb_id] = transferred_rids

                self.process_prefill_chunk()
                mbs[mb_id] = self.get_new_batch_prefill()
                self.running_mbs[mb_id] = self.running_batch

                self.cur_batch = mbs[mb_id]
                if self.cur_batch:
                    server_is_idle = False
                    result = self.run_batch(self.cur_batch)

                # send the outputs to the next step
                if self.pp_group.is_last_rank:
                    if self.cur_batch:
                        next_token_ids = result.next_token_ids
                        pp_outputs = PPProxyTensors(
                            {
                                "next_token_ids": next_token_ids,
                            }
                        )
                        # send the output from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                if ENABLE_RELEASE:
                    if self.pp_group.is_last_rank:
                        # At the last stage, all stages has reached the consensus to release memory for transferred_rids
                        release_rids = transferred_rids
                        # send to the first rank
                        self.send_pyobj_to_next_stage(release_rids)

                # receive outputs and post-process (filter finished reqs) the coming microbatch
                next_mb_id = (mb_id + 1) % self.pp_size
                next_pp_outputs = None
                next_release_rids = None

                if mbs[next_mb_id] is not None:
                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
                        self.pp_group.recv_tensor_dict(
                            all_gather_group=self.attn_tp_group
                        )
                    )
                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
                    # NOTE(review): assumes `result` is bound whenever
                    # mbs[next_mb_id] is not None — confirm invariants.
                    output_result = GenerationBatchResult(
                        logits_output=None,
                        pp_hidden_states_proxy_tensors=None,
                        next_token_ids=next_pp_outputs["next_token_ids"],
                        extend_input_len_per_req=None,
                        extend_logprob_start_len_per_req=None,
                        can_run_cuda_graph=result.can_run_cuda_graph,
                    )
                    self.process_batch_result_disagg_prefill(
                        mbs[next_mb_id], output_result
                    )

                    last_mbs[next_mb_id] = mbs[next_mb_id]

                if ENABLE_RELEASE:
                    if tmbs[next_mb_id] is not None:
                        # recv consensus rids from the previous rank
                        next_release_rids = self.recv_pyobj_from_prev_stage()
                        self.process_disagg_prefill_inflight_queue(next_release_rids)

                # carry the outputs to the next stage
                if not self.pp_group.is_last_rank:
                    if pp_outputs:
                        # send the outputs from the last round to let the next stage worker run post processing
                        self.pp_group.send_tensor_dict(
                            pp_outputs.tensors,
                            all_gather_group=self.attn_tp_group,
                        )
                    if ENABLE_RELEASE:
                        if release_rids is not None:
                            self.send_pyobj_to_next_stage(release_rids)

                if not self.pp_group.is_last_rank:
                    # send out reqs to the next stage
                    self.send_pyobj_to_next_stage(recv_reqs)
                    self.send_pyobj_to_next_stage(bootstrapped_rids)
                    self.send_pyobj_to_next_stage(transferred_rids)

                    # send out proxy tensors to the next stage
                    if self.cur_batch:
                        # FIXME(lsyin): remove this assert
                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
                        self.pp_group.send_tensor_dict(
                            result.pp_hidden_states_proxy_tensors.tensors,
                            all_gather_group=self.attn_tp_group,
                        )

                # Carry this round's received outputs/release list into the
                # next microbatch slot.
                pp_outputs = next_pp_outputs
                release_rids = next_release_rids

                self.running_batch.batch_is_full = False

            if not ENABLE_RELEASE:
                if len(self.disagg_prefill_inflight_queue) > 0:
                    self.process_disagg_prefill_inflight_queue()

            # When the server is idle, self-check and re-init some states
            if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
                self.check_memory()
                self.check_tree_cache()
                self.new_token_ratio = self.init_new_token_ratio
|
||||
@@ -1,18 +1,95 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.overlap_utils import FutureIndices
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
class GenerationBatchResult:
    """Outputs of one forward pass over a generation batch.

    Bundles the model-runner results (logits / sampled token ids, or the
    pipeline-parallel proxy tensors on non-last PP stages) together with the
    bookkeeping needed by output processing and overlap scheduling.
    """

    logits_output: Optional[LogitsProcessorOutput] = None
    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
    next_token_ids: Optional[torch.Tensor] = None
    num_accepted_tokens: Optional[int] = None
    can_run_cuda_graph: bool = False

    # For output processing
    extend_input_len_per_req: Optional[List[int]] = None
    extend_logprob_start_len_per_req: Optional[List[int]] = None

    # For overlap scheduling
    copy_done: Optional[torch.cuda.Event] = None
    delay_sample_func: Optional[callable] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to a better place?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Start async device-to-host copies of the tensors needed for
        result processing (overlap scheduling).

        Only the tensors consumed by the output processor are transferred:
        the sampled token ids, plus the logits outputs when *return_logprob*
        is set. Records ``copy_done`` so the consumer can wait on the event.
        NOTE(review): assumes ``copy_done`` was populated by the scheduler
        before this is called — confirm against the overlap path.
        """

        def _to_host(t: torch.Tensor) -> torch.Tensor:
            # Non-blocking copy; completion is tracked by the copy_done
            # event recorded at the end of this method.
            return t.to("cpu", non_blocking=True)

        if return_logprob:
            out = self.logits_output
            if out.next_token_logits is not None:
                out.next_token_logits = _to_host(out.next_token_logits)
            if out.input_token_logprobs is not None:
                out.input_token_logprobs = _to_host(out.input_token_logprobs)
        if self.logits_output.hidden_states is not None:
            self.logits_output.hidden_states = _to_host(
                self.logits_output.hidden_states
            )
        self.next_token_ids = _to_host(self.next_token_ids)

        if self.accept_lens is not None:
            self.accept_lens = _to_host(self.accept_lens)

        if self.allocate_lens is not None:
            self.allocate_lens = _to_host(self.allocate_lens)

        self.copy_done.record()

    @classmethod
    def from_pp_proxy(
        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
    ):
        """Build a result on the last PP stage from the proxy tensors
        received from the previous stage.

        Extracts ``next_token_ids`` and the optional per-request extend
        metadata from the proxy dict; ``pp_hidden_states_proxy_tensors`` is
        cleared since the proxy has been consumed.
        """
        # TODO(lsyin): refactor PP and avoid using dict
        tensors = next_pp_outputs.tensors
        return cls(
            logits_output=logits_output,
            pp_hidden_states_proxy_tensors=None,
            next_token_ids=next_pp_outputs["next_token_ids"],
            extend_input_len_per_req=tensors.get("extend_input_len_per_req", None),
            extend_logprob_start_len_per_req=tensors.get(
                "extend_logprob_start_len_per_req", None
            ),
            can_run_cuda_graph=can_run_cuda_graph,
        )
|
||||
|
||||
|
||||
def validate_input_length(
|
||||
req: Req, max_req_input_len: int, allow_auto_truncate: bool
|
||||
) -> Optional[str]:
|
||||
|
||||
Reference in New Issue
Block a user