[code move] move pp into a separate mixin (#11838)
@@ -53,13 +53,7 @@ from sglang.srt.mem_cache.memory_pool import (
     NSATokenToKVPool,
     SWAKVPool,
 )
-from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
-from sglang.srt.utils import (
-    DynamicGradMode,
-    broadcast_pyobj,
-    point_to_point_pyobj,
-    require_mlp_sync,
-)
+from sglang.srt.utils import broadcast_pyobj, point_to_point_pyobj, require_mlp_sync
 
 if TYPE_CHECKING:
     from torch.distributed import ProcessGroup
@@ -685,218 +679,6 @@ class SchedulerDisaggregationPrefillMixin:
                 return
         req.disagg_kv_sender.send(page_indices, state_indices)
 
-    # PP
-    @DynamicGradMode()
-    def event_loop_pp_disagg_prefill(self: Scheduler):
-        """
-        An event loop for the prefill server in pipeline parallelism.
-
-        Rules:
-        1. Each stage runs in the same order and is notified by the previous stage.
-        2. Each send/recv operation is blocking and matched by the neighboring stage.
-
-        Regular Schedule:
-        ====================================================================
-        Stage i                   | Stage i+1
-        send ith req              | recv ith req
-        send ith proxy            | recv ith proxy
-        send prev (i+1)th carry   | recv prev (i+1)th carry
-        ====================================================================
-
-        Prefill Server Schedule:
-        ====================================================================
-        Stage i                        | Stage i+1
-        send ith req                   | recv ith req
-        send ith bootstrap req         | recv ith bootstrap req
-        send ith transferred req       | recv ith transferred req
-        send ith proxy                 | recv ith proxy
-        send prev (i+1)th carry        | recv prev (i+1)th carry
-        send prev (i+1)th release req  | recv prev (i+1)th release req
-        ====================================================================
-
-        There are two additional elements compared to the regular schedule:
-
-        1. Bootstrap Requests:
-            a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
-            b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
-            c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
-
-        2. Transferred Requests + Release Requests:
-            a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
-            b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
-            c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
-        """
-        from sglang.srt.managers.scheduler import GenerationBatchResult
-
-        mbs = [None] * self.pp_size
-        last_mbs = [None] * self.pp_size
-        self.running_mbs = [
-            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
-        ]
-        pp_outputs: Optional[PPProxyTensors] = None
-
-        # Either success or failed
-        bootstrapped_rids: List[str] = []
-        transferred_rids: List[str] = []
-        release_rids: Optional[List[str]] = None
-
-        # transferred microbatch
-        tmbs = [None] * self.pp_size
-
-        ENABLE_RELEASE = True  # For debug
-
-        while True:
-            server_is_idle = True
-
-            for mb_id in range(self.pp_size):
-                self.running_batch = self.running_mbs[mb_id]
-                self.last_batch = last_mbs[mb_id]
-
-                recv_reqs = self.recv_requests()
-
-                self.process_input_requests(recv_reqs)
-
-                if self.pp_group.is_first_rank:
-                    # First rank, pop the bootstrap reqs from the bootstrap queue
-                    bootstrapped_reqs, failed_reqs = (
-                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
-                            return_failed_reqs=True
-                        )
-                    )
-                    bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
-                        req.rid for req in failed_reqs
-                    ]
-                    self.waiting_queue.extend(bootstrapped_reqs)
-                else:
-                    # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
-                    bootstrapped_rids = self.recv_pyobj_from_prev_stage()
-                    bootstrapped_reqs = (
-                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
-                            rids_to_check=bootstrapped_rids
-                        )
-                    )
-                    self.waiting_queue.extend(bootstrapped_reqs)
-
-                if self.pp_group.is_first_rank:
-                    transferred_rids = self.get_transferred_rids()
-                # if other ranks,
-                else:
-                    # 1. recv previous stage's transferred reqs info
-                    prev_transferred_rids = self.recv_pyobj_from_prev_stage()
-                    # 2. get the current stage's transferred reqs info
-                    curr_transferred_rids = self.get_transferred_rids()
-                    # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
-                    transferred_rids = list(
-                        set(prev_transferred_rids) & set(curr_transferred_rids)
-                    )
-
-                tmbs[mb_id] = transferred_rids
-
-                self.process_prefill_chunk()
-                mbs[mb_id] = self.get_new_batch_prefill()
-                self.running_mbs[mb_id] = self.running_batch
-
-                self.cur_batch = mbs[mb_id]
-                if self.cur_batch:
-                    server_is_idle = False
-                    result = self.run_batch(self.cur_batch)
-
-                # send the outputs to the next step
-                if self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        next_token_ids = result.next_token_ids
-                        pp_outputs = PPProxyTensors(
-                            {
-                                "next_token_ids": next_token_ids,
-                            }
-                        )
-                        # send the output from the last round to let the next stage worker run post processing
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                if ENABLE_RELEASE:
-                    if self.pp_group.is_last_rank:
-                        # At the last stage, all stages has reached the consensus to release memory for transferred_rids
-                        release_rids = transferred_rids
-                        # send to the first rank
-                        self.send_pyobj_to_next_stage(release_rids)
-
-                # receive outputs and post-process (filter finished reqs) the coming microbatch
-                next_mb_id = (mb_id + 1) % self.pp_size
-                next_pp_outputs = None
-                next_release_rids = None
-
-                if mbs[next_mb_id] is not None:
-                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
-                        self.pp_group.recv_tensor_dict(
-                            all_gather_group=self.attn_tp_group
-                        )
-                    )
-                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
-                    output_result = GenerationBatchResult(
-                        logits_output=None,
-                        pp_hidden_states_proxy_tensors=None,
-                        next_token_ids=next_pp_outputs["next_token_ids"],
-                        extend_input_len_per_req=None,
-                        extend_logprob_start_len_per_req=None,
-                        can_run_cuda_graph=result.can_run_cuda_graph,
-                    )
-                    self.process_batch_result_disagg_prefill(
-                        mbs[next_mb_id], output_result
-                    )
-
-                    last_mbs[next_mb_id] = mbs[next_mb_id]
-
-                if ENABLE_RELEASE:
-                    if tmbs[next_mb_id] is not None:
-                        # recv consensus rids from the previous rank
-                        next_release_rids = self.recv_pyobj_from_prev_stage()
-                        self.process_disagg_prefill_inflight_queue(next_release_rids)
-
-                # carry the outputs to the next stage
-                if not self.pp_group.is_last_rank:
-                    if pp_outputs:
-                        # send the outputs from the last round to let the next stage worker run post processing
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-                    if ENABLE_RELEASE:
-                        if release_rids is not None:
-                            self.send_pyobj_to_next_stage(release_rids)
-
-                if not self.pp_group.is_last_rank:
-                    # send out reqs to the next stage
-                    self.send_pyobj_to_next_stage(recv_reqs)
-                    self.send_pyobj_to_next_stage(bootstrapped_rids)
-                    self.send_pyobj_to_next_stage(transferred_rids)
-
-                    # send out proxy tensors to the next stage
-                    if self.cur_batch:
-                        # FIXME(lsyin): remove this assert
-                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
-                        self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                pp_outputs = next_pp_outputs
-                release_rids = next_release_rids
-
-                self.running_batch.batch_is_full = False
-
-            if not ENABLE_RELEASE:
-                if len(self.disagg_prefill_inflight_queue) > 0:
-                    self.process_disagg_prefill_inflight_queue()
-
-            # When the server is idle, self-check and re-init some states
-            if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
-                self.check_memory()
-                self.check_tree_cache()
-                self.new_token_ratio = self.init_new_token_ratio
-
     def send_pyobj_to_next_stage(self, data):
         if self.attn_tp_rank == 0:
             dp_offset = self.attn_dp_rank * self.attn_tp_size
@@ -78,6 +78,7 @@ from sglang.srt.utils import flatten_nested_list
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.speculative.eagle_info import EagleDraftInput
     from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -1527,8 +1528,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         if self.is_v2_eagle:
             # TODO(spec-v2): all v2 spec should go through this path
-            from sglang.srt.speculative.eagle_info import EagleDraftInput
-
             draft_input: EagleDraftInput = self.spec_info
             draft_input.prepare_for_decode(self)
 
@@ -1585,8 +1584,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     def maybe_wait_verify_done(self):
         if self.is_v2_eagle:
-            from sglang.srt.speculative.eagle_info import EagleDraftInput
-
             draft_input: EagleDraftInput = self.spec_info
             if draft_input.verify_done is not None:
                 draft_input.verify_done.synchronize()
@@ -63,7 +63,6 @@ from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.environ import envs
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -114,7 +113,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.mm_utils import init_embedding_cache
-from sglang.srt.managers.overlap_utils import FutureIndices, FutureMap
+from sglang.srt.managers.overlap_utils import FutureMap
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     ModelWorkerBatch,
@@ -136,22 +135,21 @@ from sglang.srt.managers.scheduler_metrics_mixin import (
 from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
+from sglang.srt.managers.scheduler_pp_mixin import SchedulerPPMixin
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
 from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
 from sglang.srt.managers.session_controller import Session
-from sglang.srt.managers.utils import validate_input_length
+from sglang.srt.managers.utils import GenerationBatchResult, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.mamba_radix_cache import MambaRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs, get_global_server_args
-from sglang.srt.speculative.eagle_info import EagleDraftInput
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.tracing.trace import (
     process_tracing_init,
@@ -198,77 +196,6 @@ TEST_RETRACT = get_bool_env_var("SGLANG_TEST_RETRACT")
 GRAMMAR_TIMEOUT = float(os.environ.get("SGLANG_GRAMMAR_TIMEOUT", 300))
 
 
-@dataclass
-class GenerationBatchResult:
-    logits_output: Optional[LogitsProcessorOutput] = None
-    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
-    next_token_ids: Optional[torch.Tensor] = None
-    num_accepted_tokens: Optional[int] = None
-    can_run_cuda_graph: bool = False
-
-    # For output processing
-    extend_input_len_per_req: Optional[List[int]] = None
-    extend_logprob_start_len_per_req: Optional[List[int]] = None
-
-    # For overlap scheduling
-    copy_done: Optional[torch.cuda.Event] = None
-    delay_sample_func: Optional[callable] = None
-    future_indices: Optional[FutureIndices] = None
-
-    # FIXME(lsyin): maybe move to a better place?
-    # sync path: forward stream -> output processor
-    accept_lens: Optional[torch.Tensor] = None
-    allocate_lens: Optional[torch.Tensor] = None
-
-    # relay path: forward stream -> next step forward
-    next_draft_input: Optional[EagleDraftInput] = None
-
-    def copy_to_cpu(self, return_logprob: bool = False):
-        """Copy tensors to CPU in overlap scheduling.
-        Only the tensors which are needed for processing results are copied,
-        e.g., next_token_ids, logits outputs
-        """
-        if return_logprob:
-            if self.logits_output.next_token_logits is not None:
-                self.logits_output.next_token_logits = (
-                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
-                )
-            if self.logits_output.input_token_logprobs is not None:
-                self.logits_output.input_token_logprobs = (
-                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
-                )
-            if self.logits_output.hidden_states is not None:
-                self.logits_output.hidden_states = self.logits_output.hidden_states.to(
-                    "cpu", non_blocking=True
-                )
-        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
-
-        if self.accept_lens is not None:
-            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
-
-        if self.allocate_lens is not None:
-            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
-
-        self.copy_done.record()
-
-    @classmethod
-    def from_pp_proxy(
-        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
-    ):
-        # TODO(lsyin): refactor PP and avoid using dict
-        proxy_dict = next_pp_outputs.tensors
-        return cls(
-            logits_output=logits_output,
-            pp_hidden_states_proxy_tensors=None,
-            next_token_ids=next_pp_outputs["next_token_ids"],
-            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
-            extend_logprob_start_len_per_req=proxy_dict.get(
-                "extend_logprob_start_len_per_req", None
-            ),
-            can_run_cuda_graph=can_run_cuda_graph,
-        )
-
-
 @dataclass
 class EmbeddingBatchResult:
     embeddings: torch.Tensor
@@ -281,6 +208,7 @@ class Scheduler(
     SchedulerMetricsMixin,
     SchedulerDisaggregationDecodeMixin,
     SchedulerDisaggregationPrefillMixin,
+    SchedulerPPMixin,
 ):
     """A scheduler that manages a tensor parallel GPU worker."""
 
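Note on the mixin wiring above: SchedulerPPMixin carries no state of its own; its loops only work because Python's method resolution looks up self.pp_size, self.run_batch(), self.pp_group, and so on, on the composed Scheduler instance at call time. A minimal, self-contained sketch of that composition pattern (toy class and method names, not the real sglang API):

from dataclasses import dataclass, field
from typing import List


class PPLoopMixin:
    """Adds a pipeline-parallel style loop; expects the host class to
    provide `pp_size`, `get_next_batch()`, and `run_batch()`."""

    def event_loop_pp(self, steps: int) -> List[str]:
        log = []
        for step in range(steps):
            for mb_id in range(self.pp_size):  # attribute resolved on the host class
                batch = self.get_next_batch(mb_id)
                log.append(self.run_batch(step, batch))
        return log


@dataclass
class ToyScheduler(PPLoopMixin):
    pp_size: int = 2
    queue: List[str] = field(default_factory=lambda: ["a", "b", "c", "d"])

    def get_next_batch(self, mb_id: int) -> str:
        return self.queue[mb_id % len(self.queue)]

    def run_batch(self, step: int, batch: str) -> str:
        return f"step={step} ran microbatch {batch}"


if __name__ == "__main__":
    print("\n".join(ToyScheduler().event_loop_pp(steps=1)))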
@@ -1058,128 +986,6 @@ class Scheduler(
                 self.launch_batch_sample_if_needed(batch_result)
             self.last_batch = batch
 
-    @DynamicGradMode()
-    def event_loop_pp(self):
-        """A non-overlap scheduler loop for pipeline parallelism."""
-        mbs = [None] * self.pp_size
-        last_mbs = [None] * self.pp_size
-        self.running_mbs = [
-            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
-        ]
-        pp_outputs: Optional[PPProxyTensors] = None
-        while True:
-            server_is_idle = True
-            for mb_id in range(self.pp_size):
-                self.running_batch = self.running_mbs[mb_id]
-                self.last_batch = last_mbs[mb_id]
-
-                recv_reqs = self.recv_requests()
-                self.process_input_requests(recv_reqs)
-                mbs[mb_id] = self.get_next_batch_to_run()
-                self.running_mbs[mb_id] = self.running_batch
-
-                self.cur_batch = mbs[mb_id]
-                if self.cur_batch:
-                    server_is_idle = False
-                    result = self.run_batch(self.cur_batch)
-
-                # (last rank) send the outputs to the next step
-                if self.pp_group.is_last_rank:
-                    if self.cur_batch:
-                        next_token_ids = result.next_token_ids
-                        if self.cur_batch.return_logprob:
-                            pp_outputs = PPProxyTensors(
-                                {
-                                    "next_token_ids": next_token_ids,
-                                    "extend_input_len_per_req": result.extend_input_len_per_req,
-                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
-                                }
-                                | (
-                                    {
-                                        f"logits_output.{k}": v
-                                        for k, v in result.logits_output.__dict__.items()
-                                    }
-                                    if result.logits_output is not None
-                                    else {}
-                                )
-                            )
-                        else:
-                            pp_outputs = PPProxyTensors(
-                                {
-                                    "next_token_ids": next_token_ids,
-                                }
-                            )
-                        # send the output from the last round to let the next stage worker run post processing
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                # receive outputs and post-process (filter finished reqs) the coming microbatch
-                next_mb_id = (mb_id + 1) % self.pp_size
-                next_pp_outputs = None
-                if mbs[next_mb_id] is not None:
-                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
-                        self.pp_group.recv_tensor_dict(
-                            all_gather_group=self.attn_tp_group
-                        )
-                    )
-                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
-                    logits_output_args = {
-                        k[len("logits_output.") :]: v
-                        for k, v in next_pp_outputs.tensors.items()
-                        if k.startswith("logits_output.")
-                    }
-                    if len(logits_output_args) > 0:
-                        logits_output = LogitsProcessorOutput(**logits_output_args)
-                    else:
-                        logits_output = None
-
-                    output_result = GenerationBatchResult.from_pp_proxy(
-                        logits_output=logits_output,
-                        next_pp_outputs=next_pp_outputs,
-                        can_run_cuda_graph=result.can_run_cuda_graph,
-                    )
-                    self.process_batch_result(mbs[next_mb_id], output_result)
-                    last_mbs[next_mb_id] = mbs[next_mb_id]
-
-                # (not last rank)
-                if not self.pp_group.is_last_rank:
-                    # carry the outputs to the next stage
-                    # send the outputs from the last round to let the next stage worker run post processing
-                    if pp_outputs:
-                        self.pp_group.send_tensor_dict(
-                            pp_outputs.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                    # send out reqs to the next stage
-                    dp_offset = self.attn_dp_rank * self.attn_tp_size
-                    if self.attn_tp_rank == 0:
-                        point_to_point_pyobj(
-                            recv_reqs,
-                            self.pp_rank * self.tp_size + dp_offset,
-                            self.world_group.device_group,
-                            self.pp_rank * self.tp_size + dp_offset,
-                            (self.pp_rank + 1) * self.tp_size + dp_offset,
-                        )
-
-                    # send out proxy tensors to the next stage
-                    if self.cur_batch:
-                        # FIXME(lsyin): remove this assert
-                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
-                        self.pp_group.send_tensor_dict(
-                            result.pp_hidden_states_proxy_tensors.tensors,
-                            all_gather_group=self.attn_tp_group,
-                        )
-
-                pp_outputs = next_pp_outputs
-
-            # When the server is idle, self-check and re-init some states
-            if server_is_idle:
-                # When the server is idle, do self-check and re-init some states
-                self.self_check_during_idle()
-
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
python/sglang/srt/managers/scheduler_pp_mixin.py (new file, 341 lines)
@@ -0,0 +1,341 @@
+from typing import List, Optional
+
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.utils import GenerationBatchResult
+from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
+from sglang.srt.utils import DynamicGradMode, point_to_point_pyobj
+
+
+class SchedulerPPMixin:
+
+    @DynamicGradMode()
+    def event_loop_pp(self):
+        """A non-overlap scheduler loop for pipeline parallelism."""
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        pp_outputs: Optional[PPProxyTensors] = None
+        while True:
+            server_is_idle = True
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+                mbs[mb_id] = self.get_next_batch_to_run()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # (last rank) send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids = result.next_token_ids
+                        if self.cur_batch.return_logprob:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                    "extend_input_len_per_req": result.extend_input_len_per_req,
+                                    "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
+                                }
+                                | (
+                                    {
+                                        f"logits_output.{k}": v
+                                        for k, v in result.logits_output.__dict__.items()
+                                    }
+                                    if result.logits_output is not None
+                                    else {}
+                                )
+                            )
+                        else:
+                            pp_outputs = PPProxyTensors(
+                                {
+                                    "next_token_ids": next_token_ids,
+                                }
+                            )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    logits_output_args = {
+                        k[len("logits_output.") :]: v
+                        for k, v in next_pp_outputs.tensors.items()
+                        if k.startswith("logits_output.")
+                    }
+                    if len(logits_output_args) > 0:
+                        logits_output = LogitsProcessorOutput(**logits_output_args)
+                    else:
+                        logits_output = None
+
+                    output_result = GenerationBatchResult.from_pp_proxy(
+                        logits_output=logits_output,
+                        next_pp_outputs=next_pp_outputs,
+                        can_run_cuda_graph=result.can_run_cuda_graph,
+                    )
+                    self.process_batch_result(mbs[next_mb_id], output_result)
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                # (not last rank)
+                if not self.pp_group.is_last_rank:
+                    # carry the outputs to the next stage
+                    # send the outputs from the last round to let the next stage worker run post processing
+                    if pp_outputs:
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                    # send out reqs to the next stage
+                    dp_offset = self.attn_dp_rank * self.attn_tp_size
+                    if self.attn_tp_rank == 0:
+                        point_to_point_pyobj(
+                            recv_reqs,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            self.world_group.device_group,
+                            self.pp_rank * self.tp_size + dp_offset,
+                            (self.pp_rank + 1) * self.tp_size + dp_offset,
+                        )
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle:
+                # When the server is idle, do self-check and re-init some states
+                self.self_check_during_idle()
+
+    @DynamicGradMode()
+    def event_loop_pp_disagg_prefill(self):
+        """
+        An event loop for the prefill server in pipeline parallelism.
+
+        Rules:
+        1. Each stage runs in the same order and is notified by the previous stage.
+        2. Each send/recv operation is blocking and matched by the neighboring stage.
+
+        Regular Schedule:
+        ====================================================================
+        Stage i                   | Stage i+1
+        send ith req              | recv ith req
+        send ith proxy            | recv ith proxy
+        send prev (i+1)th carry   | recv prev (i+1)th carry
+        ====================================================================
+
+        Prefill Server Schedule:
+        ====================================================================
+        Stage i                        | Stage i+1
+        send ith req                   | recv ith req
+        send ith bootstrap req         | recv ith bootstrap req
+        send ith transferred req       | recv ith transferred req
+        send ith proxy                 | recv ith proxy
+        send prev (i+1)th carry        | recv prev (i+1)th carry
+        send prev (i+1)th release req  | recv prev (i+1)th release req
+        ====================================================================
+
+        There are two additional elements compared to the regular schedule:
+
+        1. Bootstrap Requests:
+            a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
+            b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
+            c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
+
+        2. Transferred Requests + Release Requests:
+            a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
+            b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
+            c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
+        """
+        mbs = [None] * self.pp_size
+        last_mbs = [None] * self.pp_size
+        self.running_mbs = [
+            ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
+        ]
+        pp_outputs: Optional[PPProxyTensors] = None
+
+        # Either success or failed
+        bootstrapped_rids: List[str] = []
+        transferred_rids: List[str] = []
+        release_rids: Optional[List[str]] = None
+
+        # transferred microbatch
+        tmbs = [None] * self.pp_size
+
+        ENABLE_RELEASE = True  # For debug
+
+        while True:
+            server_is_idle = True
+
+            for mb_id in range(self.pp_size):
+                self.running_batch = self.running_mbs[mb_id]
+                self.last_batch = last_mbs[mb_id]
+
+                recv_reqs = self.recv_requests()
+
+                self.process_input_requests(recv_reqs)
+
+                if self.pp_group.is_first_rank:
+                    # First rank, pop the bootstrap reqs from the bootstrap queue
+                    bootstrapped_reqs, failed_reqs = (
+                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
+                            return_failed_reqs=True
+                        )
+                    )
+                    bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
+                        req.rid for req in failed_reqs
+                    ]
+                    self.waiting_queue.extend(bootstrapped_reqs)
+                else:
+                    # Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
+                    bootstrapped_rids = self.recv_pyobj_from_prev_stage()
+                    bootstrapped_reqs = (
+                        self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
+                            rids_to_check=bootstrapped_rids
+                        )
+                    )
+                    self.waiting_queue.extend(bootstrapped_reqs)
+
+                if self.pp_group.is_first_rank:
+                    transferred_rids = self.get_transferred_rids()
+                # if other ranks,
+                else:
+                    # 1. recv previous stage's transferred reqs info
+                    prev_transferred_rids = self.recv_pyobj_from_prev_stage()
+                    # 2. get the current stage's transferred reqs info
+                    curr_transferred_rids = self.get_transferred_rids()
+                    # 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
+                    transferred_rids = list(
+                        set(prev_transferred_rids) & set(curr_transferred_rids)
+                    )
+
+                tmbs[mb_id] = transferred_rids
+
+                self.process_prefill_chunk()
+                mbs[mb_id] = self.get_new_batch_prefill()
+                self.running_mbs[mb_id] = self.running_batch
+
+                self.cur_batch = mbs[mb_id]
+                if self.cur_batch:
+                    server_is_idle = False
+                    result = self.run_batch(self.cur_batch)
+
+                # send the outputs to the next step
+                if self.pp_group.is_last_rank:
+                    if self.cur_batch:
+                        next_token_ids = result.next_token_ids
+                        pp_outputs = PPProxyTensors(
+                            {
+                                "next_token_ids": next_token_ids,
+                            }
+                        )
+                        # send the output from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                if ENABLE_RELEASE:
+                    if self.pp_group.is_last_rank:
+                        # At the last stage, all stages has reached the consensus to release memory for transferred_rids
+                        release_rids = transferred_rids
+                        # send to the first rank
+                        self.send_pyobj_to_next_stage(release_rids)
+
+                # receive outputs and post-process (filter finished reqs) the coming microbatch
+                next_mb_id = (mb_id + 1) % self.pp_size
+                next_pp_outputs = None
+                next_release_rids = None
+
+                if mbs[next_mb_id] is not None:
+                    next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
+                        self.pp_group.recv_tensor_dict(
+                            all_gather_group=self.attn_tp_group
+                        )
+                    )
+                    mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
+                    output_result = GenerationBatchResult(
+                        logits_output=None,
+                        pp_hidden_states_proxy_tensors=None,
+                        next_token_ids=next_pp_outputs["next_token_ids"],
+                        extend_input_len_per_req=None,
+                        extend_logprob_start_len_per_req=None,
+                        can_run_cuda_graph=result.can_run_cuda_graph,
+                    )
+                    self.process_batch_result_disagg_prefill(
+                        mbs[next_mb_id], output_result
+                    )
+
+                    last_mbs[next_mb_id] = mbs[next_mb_id]
+
+                if ENABLE_RELEASE:
+                    if tmbs[next_mb_id] is not None:
+                        # recv consensus rids from the previous rank
+                        next_release_rids = self.recv_pyobj_from_prev_stage()
+                        self.process_disagg_prefill_inflight_queue(next_release_rids)
+
+                # carry the outputs to the next stage
+                if not self.pp_group.is_last_rank:
+                    if pp_outputs:
+                        # send the outputs from the last round to let the next stage worker run post processing
+                        self.pp_group.send_tensor_dict(
+                            pp_outputs.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+                    if ENABLE_RELEASE:
+                        if release_rids is not None:
+                            self.send_pyobj_to_next_stage(release_rids)
+
+                if not self.pp_group.is_last_rank:
+                    # send out reqs to the next stage
+                    self.send_pyobj_to_next_stage(recv_reqs)
+                    self.send_pyobj_to_next_stage(bootstrapped_rids)
+                    self.send_pyobj_to_next_stage(transferred_rids)
+
+                    # send out proxy tensors to the next stage
+                    if self.cur_batch:
+                        # FIXME(lsyin): remove this assert
+                        assert result.pp_hidden_states_proxy_tensors.tensors is not None
+                        self.pp_group.send_tensor_dict(
+                            result.pp_hidden_states_proxy_tensors.tensors,
+                            all_gather_group=self.attn_tp_group,
+                        )
+
+                pp_outputs = next_pp_outputs
+                release_rids = next_release_rids
+
+                self.running_batch.batch_is_full = False
+
+            if not ENABLE_RELEASE:
+                if len(self.disagg_prefill_inflight_queue) > 0:
+                    self.process_disagg_prefill_inflight_queue()
+
+            # When the server is idle, self-check and re-init some states
+            if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
+                self.check_memory()
+                self.check_tree_cache()
+                self.new_token_ratio = self.init_new_token_ratio
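The release protocol in the docstring above reduces to a running set intersection propagated stage by stage: the first rank seeds the set with its own transfer-finished request ids, each later rank intersects it with its local set, and whatever survives at the last rank has, by construction, finished on every stage and can be released. A small self-contained simulation of that invariant (hypothetical request ids, no distributed calls):

from typing import List, Set


def propagate_consensus(finished_per_stage: List[Set[str]]) -> Set[str]:
    """Mimic the per-stage intersection from the docstring: stage 0 starts
    the set, each following stage intersects it with its own finished set."""
    consensus = set(finished_per_stage[0])
    for local_finished in finished_per_stage[1:]:
        consensus &= local_finished  # recv from prev stage, intersect, send on
    return consensus  # what the last rank would turn into release ids


if __name__ == "__main__":
    per_stage: List[Set[str]] = [
        {"req-1", "req-2", "req-3"},  # stage 0: transfer finished locally
        {"req-1", "req-3"},           # stage 1
        {"req-1", "req-3", "req-4"},  # stage 2 (last)
    ]
    release_ids = propagate_consensus(per_stage)
    assert release_ids == {"req-1", "req-3"}
    print(sorted(release_ids))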
@@ -1,18 +1,95 @@
 from __future__ import annotations
 
+import dataclasses
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.managers.overlap_utils import FutureIndices
 from sglang.srt.managers.schedule_batch import Req
 from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
 
 if TYPE_CHECKING:
     from sglang.srt.managers.scheduler import GenerationBatchResult
+    from sglang.srt.speculative.eagle_info import EagleDraftInput
+
 
 logger = logging.getLogger(__name__)
 
 
+@dataclasses.dataclass
+class GenerationBatchResult:
+    logits_output: Optional[LogitsProcessorOutput] = None
+    pp_hidden_states_proxy_tensors: Optional[PPProxyTensors] = None
+    next_token_ids: Optional[torch.Tensor] = None
+    num_accepted_tokens: Optional[int] = None
+    can_run_cuda_graph: bool = False
+
+    # For output processing
+    extend_input_len_per_req: Optional[List[int]] = None
+    extend_logprob_start_len_per_req: Optional[List[int]] = None
+
+    # For overlap scheduling
+    copy_done: Optional[torch.cuda.Event] = None
+    delay_sample_func: Optional[callable] = None
+    future_indices: Optional[FutureIndices] = None
+
+    # FIXME(lsyin): maybe move to a better place?
+    # sync path: forward stream -> output processor
+    accept_lens: Optional[torch.Tensor] = None
+    allocate_lens: Optional[torch.Tensor] = None
+
+    # relay path: forward stream -> next step forward
+    next_draft_input: Optional[EagleDraftInput] = None
+
+    def copy_to_cpu(self, return_logprob: bool = False):
+        """Copy tensors to CPU in overlap scheduling.
+        Only the tensors which are needed for processing results are copied,
+        e.g., next_token_ids, logits outputs
+        """
+        if return_logprob:
+            if self.logits_output.next_token_logits is not None:
+                self.logits_output.next_token_logits = (
+                    self.logits_output.next_token_logits.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.input_token_logprobs is not None:
+                self.logits_output.input_token_logprobs = (
+                    self.logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                )
+            if self.logits_output.hidden_states is not None:
+                self.logits_output.hidden_states = self.logits_output.hidden_states.to(
+                    "cpu", non_blocking=True
+                )
+        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)
+
+        if self.accept_lens is not None:
+            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)
+
+        if self.allocate_lens is not None:
+            self.allocate_lens = self.allocate_lens.to("cpu", non_blocking=True)
+
+        self.copy_done.record()
+
+    @classmethod
+    def from_pp_proxy(
+        cls, logits_output, next_pp_outputs: PPProxyTensors, can_run_cuda_graph
+    ):
+        # TODO(lsyin): refactor PP and avoid using dict
+        proxy_dict = next_pp_outputs.tensors
+        return cls(
+            logits_output=logits_output,
+            pp_hidden_states_proxy_tensors=None,
+            next_token_ids=next_pp_outputs["next_token_ids"],
+            extend_input_len_per_req=proxy_dict.get("extend_input_len_per_req", None),
+            extend_logprob_start_len_per_req=proxy_dict.get(
+                "extend_logprob_start_len_per_req", None
+            ),
+            can_run_cuda_graph=can_run_cuda_graph,
+        )
+
+
 def validate_input_length(
     req: Req, max_req_input_len: int, allow_auto_truncate: bool
 ) -> Optional[str]:
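The from_pp_proxy classmethod above pairs with the flat "logits_output.<field>" key convention used in event_loop_pp: the last rank flattens LogitsProcessorOutput fields into the proxy dict, and the receiving stage strips the prefix to rebuild the object before constructing a GenerationBatchResult. A self-contained sketch of just that pack/unpack convention with plain dicts (toy field names, no sglang imports; not the real PPProxyTensors API):

from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple


@dataclass
class ToyLogitsOutput:
    next_token_logits: Optional[list] = None
    hidden_states: Optional[list] = None


def pack(next_token_ids: list, logits_output: Optional[ToyLogitsOutput]) -> Dict[str, Any]:
    # Sender side: flatten the nested object into prefixed keys, as event_loop_pp does.
    out: Dict[str, Any] = {"next_token_ids": next_token_ids}
    if logits_output is not None:
        out |= {f"logits_output.{k}": v for k, v in logits_output.__dict__.items()}
    return out


def unpack(proxy: Dict[str, Any]) -> Tuple[list, Optional[ToyLogitsOutput]]:
    # Receiver side: strip the prefix back off and rebuild the object.
    args = {
        k[len("logits_output."):]: v
        for k, v in proxy.items()
        if k.startswith("logits_output.")
    }
    logits_output = ToyLogitsOutput(**args) if args else None
    return proxy["next_token_ids"], logits_output


if __name__ == "__main__":
    sent = pack([7, 11], ToyLogitsOutput(next_token_logits=[0.1, 0.9]))
    ids, rebuilt = unpack(sent)
    assert ids == [7, 11] and rebuilt.next_token_logits == [0.1, 0.9]
    print(ids, rebuilt)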