diff --git a/python/sglang/srt/disaggregation/decode.py b/python/sglang/srt/disaggregation/decode.py index 628858612..62f0dfd42 100644 --- a/python/sglang/srt/disaggregation/decode.py +++ b/python/sglang/srt/disaggregation/decode.py @@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import ( poll_and_all_reduce, prepare_abort, ) +from sglang.srt.managers.schedule_batch import FINISH_ABORT from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardMode @@ -321,11 +322,15 @@ class DecodeTransferQueue: gloo_group: ProcessGroup, req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator, metadata_buffers: torch.Tensor, + scheduler: Scheduler, + tree_cache: BasePrefixCache, ): self.queue: List[DecodeRequest] = [] self.gloo_group = gloo_group self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator self.metadata_buffers = metadata_buffers + self.scheduler = scheduler + self.tree_cache = tree_cache def add(self, req_conn: DecodeRequest) -> None: self.queue.append(req_conn) @@ -341,6 +346,14 @@ class DecodeTransferQueue: [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group ) + # First, remove all failed requests from the queue + indices_to_remove = set() + for i, decode_req in enumerate(self.queue): + if isinstance(decode_req.req.finished_reason, FINISH_ABORT): + self.scheduler.stream_output( + [decode_req.req], decode_req.req.return_logprob + ) + indices_to_remove.add(i) + transferred_reqs = [] - indices_to_remove = set() for i, (decode_req, poll) in enumerate(zip(self.queue, polls)): @@ -396,95 +409,6 @@ class DecodeTransferQueue: return transferred_reqs -class ScheduleBatchDisaggregationDecodeMixin: - - def prepare_for_prebuilt_extend(self: ScheduleBatch): - """ - Prepare a prebuilt extend by populate metadata - Adapted from .prepare_for_extend(). 
- """ - - self.forward_mode = ForwardMode.EXTEND - reqs = self.reqs - input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] - extend_num_tokens = sum(len(ids) for ids in input_ids) - seq_lens = [] - pre_lens = [] - req_pool_indices = [] - - # Pre-calculate total size - total_size = sum(req.extend_input_len for req in reqs) - out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device) - - # Fill the tensor in one pass - offset = 0 - for i, req in enumerate(reqs): - req_pool_indices.append(req.req_pool_idx) - - chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][ - : req.extend_input_len - ] - assert ( - offset + req.extend_input_len <= total_size - ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}" - out_cache_loc[offset : offset + req.extend_input_len] = chunk - offset += req.extend_input_len - - pre_len = len(req.prefix_indices) - seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1) - seq_lens.append(seq_len) - if len(req.output_ids) == 0: - assert ( - seq_len - pre_len == req.extend_input_len - ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}" - - req.cached_tokens += pre_len - req.already_computed - req.already_computed = seq_len - req.is_retracted = False - pre_lens.append(pre_len) - req.extend_logprob_start_len = 0 - - extend_input_logprob_token_ids = None - - # Set fields - self.input_ids = torch.tensor( - sum(input_ids, []), dtype=torch.int32, device=self.device - ) - self.req_pool_indices = torch.tensor( - req_pool_indices, dtype=torch.int64, device=self.device - ) - self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) - self.out_cache_loc = out_cache_loc - self.seq_lens_sum = sum(seq_lens) - self.extend_num_tokens = extend_num_tokens - self.prefix_lens = [len(r.prefix_indices) for r in reqs] - self.extend_lens = [r.extend_input_len for r in reqs] - self.extend_logprob_start_lens 
= [r.extend_logprob_start_len for r in reqs] - self.extend_input_logprob_token_ids = extend_input_logprob_token_ids - - # Build sampling info - self.sampling_info = SamplingBatchInfo.from_schedule_batch( - self, - self.model_config.vocab_size, - ) - - def process_prebuilt_extend( - self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig - ): - """Assign the buffered last input id to schedule batch""" - self.output_ids = [] - for req in self.reqs: - if req.output_ids and len(req.output_ids) > 0: - # resumed retracted req - self.output_ids.append(req.output_ids[-1]) - else: - assert req.transferred_output_id is not None - req.output_ids.append(req.transferred_output_id) - self.output_ids.append(req.transferred_output_id) - self.tree_cache.cache_unfinished_req(req) - self.output_ids = torch.tensor(self.output_ids, device=self.device) - - class SchedulerDisaggregationDecodeMixin: def _prepare_idle_batch_and_run(self, batch, delay_process=False): diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py new file mode 100644 index 000000000..0c2e763c2 --- /dev/null +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import torch + +from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from sglang.srt.configs.model_config import ModelConfig + from sglang.srt.managers.schedule_batch import ScheduleBatch + from sglang.srt.server_args import ServerArgs + + +class ScheduleBatchDisaggregationDecodeMixin: + + def prepare_for_prebuilt_extend(self: ScheduleBatch): + """ + Prepare a prebuilt extend by populate metadata + Adapted from .prepare_for_extend(). 
+ """ + + self.forward_mode = ForwardMode.EXTEND + reqs = self.reqs + input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs] + extend_num_tokens = sum(len(ids) for ids in input_ids) + seq_lens = [] + pre_lens = [] + req_pool_indices = [] + + # Pre-calculate total size + total_size = sum(req.extend_input_len for req in reqs) + out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device) + + # Fill the tensor in one pass + offset = 0 + for i, req in enumerate(reqs): + req_pool_indices.append(req.req_pool_idx) + + chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][ + : req.extend_input_len + ] + assert ( + offset + req.extend_input_len <= total_size + ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}" + out_cache_loc[offset : offset + req.extend_input_len] = chunk + offset += req.extend_input_len + + pre_len = len(req.prefix_indices) + seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1) + seq_lens.append(seq_len) + if len(req.output_ids) == 0: + assert ( + seq_len - pre_len == req.extend_input_len + ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}" + + req.cached_tokens += pre_len - req.already_computed + req.already_computed = seq_len + req.is_retracted = False + pre_lens.append(pre_len) + req.extend_logprob_start_len = 0 + + extend_input_logprob_token_ids = None + + # Set fields + self.input_ids = torch.tensor( + sum(input_ids, []), dtype=torch.int32, device=self.device + ) + self.req_pool_indices = torch.tensor( + req_pool_indices, dtype=torch.int64, device=self.device + ) + self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device) + self.out_cache_loc = out_cache_loc + self.seq_lens_sum = sum(seq_lens) + self.extend_num_tokens = extend_num_tokens + self.prefix_lens = [len(r.prefix_indices) for r in reqs] + self.extend_lens = [r.extend_input_len for r in reqs] + self.extend_logprob_start_lens 
= [r.extend_logprob_start_len for r in reqs] + self.extend_input_logprob_token_ids = extend_input_logprob_token_ids + + # Build sampling info + self.sampling_info = SamplingBatchInfo.from_schedule_batch( + self, + self.model_config.vocab_size, + ) + + def process_prebuilt_extend( + self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig + ): + """Assign the buffered last input id to schedule batch""" + self.output_ids = [] + for req in self.reqs: + if req.output_ids and len(req.output_ids) > 0: + # resumed retracted req + self.output_ids.append(req.output_ids[-1]) + else: + assert req.transferred_output_id is not None + req.output_ids.append(req.transferred_output_id) + self.output_ids.append(req.transferred_output_id) + self.tree_cache.cache_unfinished_req(req) + self.output_ids = torch.tensor(self.output_ids, device=self.device) diff --git a/python/sglang/srt/disaggregation/utils.py b/python/sglang/srt/disaggregation/utils.py index 69348d4ff..5a5eb4780 100644 --- a/python/sglang/srt/disaggregation/utils.py +++ b/python/sglang/srt/disaggregation/utils.py @@ -1,6 +1,8 @@ from __future__ import annotations import dataclasses +import os +import random import warnings from collections import deque from enum import Enum @@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip FakeBootstrapHost = "2.2.2.2" +# env var for testing failure, convert to float explicitly +FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0)) + class DisaggregationMode(Enum): NULL = "null" @@ -23,7 +28,16 @@ class DisaggregationMode(Enum): def poll_and_all_reduce(pollers, gloo_group): - polls = [int(poller.poll()) for poller in pollers] + # at a certain prob, the poll is failed to simulate failure + if FAILURE_PROB > 0: + from sglang.srt.disaggregation.base import KVPoll + + polls = [ + int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll()) + for poller in pollers + ] + else: + polls = [int(poller.poll()) for poller in pollers] 
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu") dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group) return tensor_to_reduce.tolist() diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index cd780b1ac..abc466fa9 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -48,7 +48,9 @@ from sglang.global_config import global_config from sglang.srt.configs.model_config import ModelConfig from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject from sglang.srt.disaggregation.base import BaseKVSender -from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin +from sglang.srt.disaggregation.decode_schedule_batch_mixin import ( + ScheduleBatchDisaggregationDecodeMixin, +) from sglang.srt.layers.multimodal import gpu_tensor_hash from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache from sglang.srt.mem_cache.chunk_cache import ChunkCache diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 62af0c6b3..33e599208 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -582,6 +582,8 @@ class Scheduler( gloo_group=self.attn_tp_cpu_group, req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator, metadata_buffers=metadata_buffers, + scheduler=self, + tree_cache=self.tree_cache, ) # The decode requests pending for pre-allocation diff --git a/python/sglang/srt/mem_cache/chunk_cache.py b/python/sglang/srt/mem_cache/chunk_cache.py index 2c51cb855..b09d3723b 100644 --- a/python/sglang/srt/mem_cache/chunk_cache.py +++ b/python/sglang/srt/mem_cache/chunk_cache.py @@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache): def cache_finished_req(self, req: Req): kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 
1 + req.req_pool_idx, + # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids + : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0), ] self.req_to_token_pool.free(req.req_pool_idx) self.token_to_kv_pool_allocator.free(kv_indices)