[PD] Fix failure abort (#6535)
python/sglang/srt/disaggregation/decode.py
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
    poll_and_all_reduce,
    prepare_abort,
)
from sglang.srt.managers.schedule_batch import FINISH_ABORT
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardMode
@@ -321,11 +322,15 @@ class DecodeTransferQueue:
        gloo_group: ProcessGroup,
        req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
        metadata_buffers: torch.Tensor,
        scheduler: Scheduler,
        tree_cache: BasePrefixCache,
    ):
        self.queue: List[DecodeRequest] = []
        self.gloo_group = gloo_group
        self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
        self.metadata_buffers = metadata_buffers
        self.scheduler = scheduler
        self.tree_cache = tree_cache

    def add(self, req_conn: DecodeRequest) -> None:
        self.queue.append(req_conn)

@@ -341,6 +346,14 @@ class DecodeTransferQueue:
            [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
        )

        transferred_reqs = []
        indices_to_remove = set()

        # First, remove all already-aborted requests from the queue
        for i, decode_req in enumerate(self.queue):
            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
                self.scheduler.stream_output(
                    [decode_req.req], decode_req.req.return_logprob
                )
                indices_to_remove.add(i)

        for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -396,95 +409,6 @@ class DecodeTransferQueue:
        return transferred_reqs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populating metadata.
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)


class SchedulerDisaggregationDecodeMixin:

    def _prepare_idle_batch_and_run(self, batch, delay_process=False):
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING

import torch

from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
    from sglang.srt.managers.schedule_batch import ScheduleBatch
    from sglang.srt.server_args import ServerArgs


class ScheduleBatchDisaggregationDecodeMixin:

    def prepare_for_prebuilt_extend(self: ScheduleBatch):
        """
        Prepare a prebuilt extend by populating metadata.
        Adapted from .prepare_for_extend().
        """

        self.forward_mode = ForwardMode.EXTEND
        reqs = self.reqs
        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
        extend_num_tokens = sum(len(ids) for ids in input_ids)
        seq_lens = []
        pre_lens = []
        req_pool_indices = []

        # Pre-calculate total size
        total_size = sum(req.extend_input_len for req in reqs)
        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)

        # Fill the tensor in one pass
        offset = 0
        for i, req in enumerate(reqs):
            req_pool_indices.append(req.req_pool_idx)

            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
                : req.extend_input_len
            ]
            assert (
                offset + req.extend_input_len <= total_size
            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
            out_cache_loc[offset : offset + req.extend_input_len] = chunk
            offset += req.extend_input_len

            pre_len = len(req.prefix_indices)
            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
            seq_lens.append(seq_len)
            if len(req.output_ids) == 0:
                assert (
                    seq_len - pre_len == req.extend_input_len
                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"

            req.cached_tokens += pre_len - req.already_computed
            req.already_computed = seq_len
            req.is_retracted = False
            pre_lens.append(pre_len)
            req.extend_logprob_start_len = 0

        extend_input_logprob_token_ids = None

        # Set fields
        self.input_ids = torch.tensor(
            sum(input_ids, []), dtype=torch.int32, device=self.device
        )
        self.req_pool_indices = torch.tensor(
            req_pool_indices, dtype=torch.int64, device=self.device
        )
        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
        self.out_cache_loc = out_cache_loc
        self.seq_lens_sum = sum(seq_lens)
        self.extend_num_tokens = extend_num_tokens
        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
        self.extend_lens = [r.extend_input_len for r in reqs]
        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids

        # Build sampling info
        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
            self,
            self.model_config.vocab_size,
        )

    def process_prebuilt_extend(
        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
    ):
        """Assign the buffered last input id to schedule batch"""
        self.output_ids = []
        for req in self.reqs:
            if req.output_ids and len(req.output_ids) > 0:
                # resumed retracted req
                self.output_ids.append(req.output_ids[-1])
            else:
                assert req.transferred_output_id is not None
                req.output_ids.append(req.transferred_output_id)
                self.output_ids.append(req.transferred_output_id)
            self.tree_cache.cache_unfinished_req(req)
        self.output_ids = torch.tensor(self.output_ids, device=self.device)
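For orientation, the else branch in process_prebuilt_extend is the normal disaggregated path: the prefill server samples the request's first token and ships it over as transferred_output_id, which becomes both the request's first output id and the input id for the first decode step; a resumed retracted request instead reuses its own last output id. A minimal sketch of the two branches (ReqStub is a hypothetical stand-in carrying only the fields the method touches):

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ReqStub:
    # Hypothetical stand-in: only the fields process_prebuilt_extend touches.
    output_ids: List[int] = field(default_factory=list)
    transferred_output_id: Optional[int] = None


def last_input_id(req: ReqStub) -> int:
    # Mirrors the branch logic above.
    if req.output_ids:
        # Resumed retracted request: reuse its own last output id.
        return req.output_ids[-1]
    assert req.transferred_output_id is not None
    req.output_ids.append(req.transferred_output_id)
    return req.transferred_output_id


# Fresh request: the first token was sampled on the prefill server.
fresh = ReqStub(transferred_output_id=42)
assert last_input_id(fresh) == 42 and fresh.output_ids == [42]

# Resumed retracted request: it already has output history.
resumed = ReqStub(output_ids=[7, 8, 9])
assert last_input_id(resumed) == 9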
python/sglang/srt/disaggregation/utils.py
@@ -1,6 +1,8 @@
from __future__ import annotations

import dataclasses
import os
import random
import warnings
from collections import deque
from enum import Enum

@@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip

FakeBootstrapHost = "2.2.2.2"

# Env var for testing failures; convert to float explicitly
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))


class DisaggregationMode(Enum):
    NULL = "null"

@@ -23,7 +28,16 @@ class DisaggregationMode(Enum):

def poll_and_all_reduce(pollers, gloo_group):
    # With some probability, report a failed poll to simulate a transfer failure
    if FAILURE_PROB > 0:
        from sglang.srt.disaggregation.base import KVPoll

        polls = [
            int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
            for poller in pollers
        ]
    else:
        polls = [int(poller.poll()) for poller in pollers]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()
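Since the all-reduce takes the element-wise MIN across ranks, a single rank that reports KVPoll.Failed makes every rank observe the failure, which is what makes this probe usable as a fault-injection hook. A runnable single-process sketch of that behavior (FakePoller and the numeric value for Failed are illustrative assumptions, not sglang APIs):

import os
import random

import torch
import torch.distributed as dist

KV_POLL_FAILED = 0  # assumption: Failed maps to the smallest enum value
FAILURE_PROB = 0.5  # stand-in for DISAGGREGATION_TEST_FAILURE_PROB


class FakePoller:
    """Illustrative stub: a poller only needs a poll() method."""

    def poll(self) -> int:
        return 4  # pretend "Success"


def poll_and_all_reduce(pollers, gloo_group):
    # Same shape as the function above, with the fault injection inlined.
    polls = [
        KV_POLL_FAILED if random.random() < FAILURE_PROB else int(p.poll())
        for p in pollers
    ]
    tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
    dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
    return tensor_to_reduce.tolist()


if __name__ == "__main__":
    # Single-process gloo group, just enough to make all_reduce runnable.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(poll_and_all_reduce([FakePoller() for _ in range(4)], dist.group.WORLD))
    dist.destroy_process_group()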
python/sglang/srt/managers/schedule_batch.py
@@ -48,7 +48,9 @@ from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.disaggregation.base import BaseKVSender
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
    ScheduleBatchDisaggregationDecodeMixin,
)
from sglang.srt.layers.multimodal import gpu_tensor_hash
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
python/sglang/srt/managers/scheduler.py
@@ -582,6 +582,8 @@ class Scheduler(
            gloo_group=self.attn_tp_cpu_group,
            req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
            metadata_buffers=metadata_buffers,
            scheduler=self,
            tree_cache=self.tree_cache,
        )

        # The decode requests pending for pre-allocation
python/sglang/srt/mem_cache/chunk_cache.py
@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):

    def cache_finished_req(self, req: Req):
        kv_indices = self.req_to_token_pool.req_to_token[
            req.req_pool_idx,
            # For the decode server: if req.output_ids is empty, free all of req.origin_input_ids
            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
        ]
        self.req_to_token_pool.free(req.req_pool_idx)
        self.token_to_kv_pool_allocator.free(kv_indices)
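The max(..., 0) bound is the substance of this hunk: on the decode server a request can finish (for example, be aborted) before emitting any output token, and the old expression then freed one slot too few. A worked check with hypothetical token counts:

def kv_len(num_input_tokens: int, num_output_tokens: int) -> int:
    # New slice bound used by cache_finished_req above.
    return num_input_tokens + max(num_output_tokens - 1, 0)

# Normal finished request: 10 prompt tokens, 3 generated tokens.
# The last output token has no KV entry yet, so 12 slots are freed.
assert kv_len(10, 3) == 12

# Aborted before any output (decode server): all 10 prompt slots are freed.
# The old bound, 10 + 0 - 1 = 9, would have leaked one KV slot.
assert kv_len(10, 0) == 10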