[PD] Fix failure abort (#6535)
This commit is contained in:
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
poll_and_all_reduce,
|
poll_and_all_reduce,
|
||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
@@ -321,11 +322,15 @@ class DecodeTransferQueue:
|
|||||||
gloo_group: ProcessGroup,
|
gloo_group: ProcessGroup,
|
||||||
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
|
||||||
metadata_buffers: torch.Tensor,
|
metadata_buffers: torch.Tensor,
|
||||||
|
scheduler: Scheduler,
|
||||||
|
tree_cache: BasePrefixCache,
|
||||||
):
|
):
|
||||||
self.queue: List[DecodeRequest] = []
|
self.queue: List[DecodeRequest] = []
|
||||||
self.gloo_group = gloo_group
|
self.gloo_group = gloo_group
|
||||||
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
|
||||||
self.metadata_buffers = metadata_buffers
|
self.metadata_buffers = metadata_buffers
|
||||||
|
self.scheduler = scheduler
|
||||||
|
self.tree_cache = tree_cache
|
||||||
|
|
||||||
def add(self, req_conn: DecodeRequest) -> None:
|
def add(self, req_conn: DecodeRequest) -> None:
|
||||||
self.queue.append(req_conn)
|
self.queue.append(req_conn)
|
||||||
@@ -341,6 +346,14 @@ class DecodeTransferQueue:
|
|||||||
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
[decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# First, remove all failed requests from the queue
|
||||||
|
for i, decode_req in enumerate(self.queue):
|
||||||
|
if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
|
||||||
|
self.scheduler.stream_output(
|
||||||
|
[decode_req.req], decode_req.req.return_logprob
|
||||||
|
)
|
||||||
|
indices_to_remove.add(i)
|
||||||
|
|
||||||
transferred_reqs = []
|
transferred_reqs = []
|
||||||
indices_to_remove = set()
|
indices_to_remove = set()
|
||||||
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
@@ -396,95 +409,6 @@ class DecodeTransferQueue:
|
|||||||
return transferred_reqs
|
return transferred_reqs
|
||||||
|
|
||||||
|
|
||||||
class ScheduleBatchDisaggregationDecodeMixin:
|
|
||||||
|
|
||||||
def prepare_for_prebuilt_extend(self: ScheduleBatch):
|
|
||||||
"""
|
|
||||||
Prepare a prebuilt extend by populate metadata
|
|
||||||
Adapted from .prepare_for_extend().
|
|
||||||
"""
|
|
||||||
|
|
||||||
self.forward_mode = ForwardMode.EXTEND
|
|
||||||
reqs = self.reqs
|
|
||||||
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
|
||||||
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
|
||||||
seq_lens = []
|
|
||||||
pre_lens = []
|
|
||||||
req_pool_indices = []
|
|
||||||
|
|
||||||
# Pre-calculate total size
|
|
||||||
total_size = sum(req.extend_input_len for req in reqs)
|
|
||||||
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
|
|
||||||
|
|
||||||
# Fill the tensor in one pass
|
|
||||||
offset = 0
|
|
||||||
for i, req in enumerate(reqs):
|
|
||||||
req_pool_indices.append(req.req_pool_idx)
|
|
||||||
|
|
||||||
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
: req.extend_input_len
|
|
||||||
]
|
|
||||||
assert (
|
|
||||||
offset + req.extend_input_len <= total_size
|
|
||||||
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
|
|
||||||
out_cache_loc[offset : offset + req.extend_input_len] = chunk
|
|
||||||
offset += req.extend_input_len
|
|
||||||
|
|
||||||
pre_len = len(req.prefix_indices)
|
|
||||||
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
|
|
||||||
seq_lens.append(seq_len)
|
|
||||||
if len(req.output_ids) == 0:
|
|
||||||
assert (
|
|
||||||
seq_len - pre_len == req.extend_input_len
|
|
||||||
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
|
|
||||||
|
|
||||||
req.cached_tokens += pre_len - req.already_computed
|
|
||||||
req.already_computed = seq_len
|
|
||||||
req.is_retracted = False
|
|
||||||
pre_lens.append(pre_len)
|
|
||||||
req.extend_logprob_start_len = 0
|
|
||||||
|
|
||||||
extend_input_logprob_token_ids = None
|
|
||||||
|
|
||||||
# Set fields
|
|
||||||
self.input_ids = torch.tensor(
|
|
||||||
sum(input_ids, []), dtype=torch.int32, device=self.device
|
|
||||||
)
|
|
||||||
self.req_pool_indices = torch.tensor(
|
|
||||||
req_pool_indices, dtype=torch.int64, device=self.device
|
|
||||||
)
|
|
||||||
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
|
||||||
self.out_cache_loc = out_cache_loc
|
|
||||||
self.seq_lens_sum = sum(seq_lens)
|
|
||||||
self.extend_num_tokens = extend_num_tokens
|
|
||||||
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
|
||||||
self.extend_lens = [r.extend_input_len for r in reqs]
|
|
||||||
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
|
||||||
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
|
||||||
|
|
||||||
# Build sampling info
|
|
||||||
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
|
||||||
self,
|
|
||||||
self.model_config.vocab_size,
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_prebuilt_extend(
|
|
||||||
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
|
|
||||||
):
|
|
||||||
"""Assign the buffered last input id to schedule batch"""
|
|
||||||
self.output_ids = []
|
|
||||||
for req in self.reqs:
|
|
||||||
if req.output_ids and len(req.output_ids) > 0:
|
|
||||||
# resumed retracted req
|
|
||||||
self.output_ids.append(req.output_ids[-1])
|
|
||||||
else:
|
|
||||||
assert req.transferred_output_id is not None
|
|
||||||
req.output_ids.append(req.transferred_output_id)
|
|
||||||
self.output_ids.append(req.transferred_output_id)
|
|
||||||
self.tree_cache.cache_unfinished_req(req)
|
|
||||||
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulerDisaggregationDecodeMixin:
|
class SchedulerDisaggregationDecodeMixin:
|
||||||
|
|
||||||
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
def _prepare_idle_batch_and_run(self, batch, delay_process=False):
|
||||||
|
|||||||
105
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
Normal file
105
python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
|
|
||||||
|
class ScheduleBatchDisaggregationDecodeMixin:
|
||||||
|
|
||||||
|
def prepare_for_prebuilt_extend(self: ScheduleBatch):
|
||||||
|
"""
|
||||||
|
Prepare a prebuilt extend by populate metadata
|
||||||
|
Adapted from .prepare_for_extend().
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.forward_mode = ForwardMode.EXTEND
|
||||||
|
reqs = self.reqs
|
||||||
|
input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
|
||||||
|
extend_num_tokens = sum(len(ids) for ids in input_ids)
|
||||||
|
seq_lens = []
|
||||||
|
pre_lens = []
|
||||||
|
req_pool_indices = []
|
||||||
|
|
||||||
|
# Pre-calculate total size
|
||||||
|
total_size = sum(req.extend_input_len for req in reqs)
|
||||||
|
out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
|
||||||
|
|
||||||
|
# Fill the tensor in one pass
|
||||||
|
offset = 0
|
||||||
|
for i, req in enumerate(reqs):
|
||||||
|
req_pool_indices.append(req.req_pool_idx)
|
||||||
|
|
||||||
|
chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
|
||||||
|
: req.extend_input_len
|
||||||
|
]
|
||||||
|
assert (
|
||||||
|
offset + req.extend_input_len <= total_size
|
||||||
|
), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
|
||||||
|
out_cache_loc[offset : offset + req.extend_input_len] = chunk
|
||||||
|
offset += req.extend_input_len
|
||||||
|
|
||||||
|
pre_len = len(req.prefix_indices)
|
||||||
|
seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
|
||||||
|
seq_lens.append(seq_len)
|
||||||
|
if len(req.output_ids) == 0:
|
||||||
|
assert (
|
||||||
|
seq_len - pre_len == req.extend_input_len
|
||||||
|
), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
|
||||||
|
|
||||||
|
req.cached_tokens += pre_len - req.already_computed
|
||||||
|
req.already_computed = seq_len
|
||||||
|
req.is_retracted = False
|
||||||
|
pre_lens.append(pre_len)
|
||||||
|
req.extend_logprob_start_len = 0
|
||||||
|
|
||||||
|
extend_input_logprob_token_ids = None
|
||||||
|
|
||||||
|
# Set fields
|
||||||
|
self.input_ids = torch.tensor(
|
||||||
|
sum(input_ids, []), dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
|
self.req_pool_indices = torch.tensor(
|
||||||
|
req_pool_indices, dtype=torch.int64, device=self.device
|
||||||
|
)
|
||||||
|
self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
|
||||||
|
self.out_cache_loc = out_cache_loc
|
||||||
|
self.seq_lens_sum = sum(seq_lens)
|
||||||
|
self.extend_num_tokens = extend_num_tokens
|
||||||
|
self.prefix_lens = [len(r.prefix_indices) for r in reqs]
|
||||||
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
||||||
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
||||||
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
||||||
|
|
||||||
|
# Build sampling info
|
||||||
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
||||||
|
self,
|
||||||
|
self.model_config.vocab_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_prebuilt_extend(
|
||||||
|
self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
|
||||||
|
):
|
||||||
|
"""Assign the buffered last input id to schedule batch"""
|
||||||
|
self.output_ids = []
|
||||||
|
for req in self.reqs:
|
||||||
|
if req.output_ids and len(req.output_ids) > 0:
|
||||||
|
# resumed retracted req
|
||||||
|
self.output_ids.append(req.output_ids[-1])
|
||||||
|
else:
|
||||||
|
assert req.transferred_output_id is not None
|
||||||
|
req.output_ids.append(req.transferred_output_id)
|
||||||
|
self.output_ids.append(req.transferred_output_id)
|
||||||
|
self.tree_cache.cache_unfinished_req(req)
|
||||||
|
self.output_ids = torch.tensor(self.output_ids, device=self.device)
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import os
|
||||||
|
import random
|
||||||
import warnings
|
import warnings
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
@@ -15,6 +17,9 @@ from sglang.srt.utils import get_ip
|
|||||||
|
|
||||||
FakeBootstrapHost = "2.2.2.2"
|
FakeBootstrapHost = "2.2.2.2"
|
||||||
|
|
||||||
|
# env var for testing failure, convert to float explicitly
|
||||||
|
FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", 0))
|
||||||
|
|
||||||
|
|
||||||
class DisaggregationMode(Enum):
|
class DisaggregationMode(Enum):
|
||||||
NULL = "null"
|
NULL = "null"
|
||||||
@@ -23,6 +28,15 @@ class DisaggregationMode(Enum):
|
|||||||
|
|
||||||
|
|
||||||
def poll_and_all_reduce(pollers, gloo_group):
|
def poll_and_all_reduce(pollers, gloo_group):
|
||||||
|
# at a certain prob, the poll is failed to simulate failure
|
||||||
|
if FAILURE_PROB > 0:
|
||||||
|
from sglang.srt.disaggregation.base import KVPoll
|
||||||
|
|
||||||
|
polls = [
|
||||||
|
int(KVPoll.Failed) if random.random() < FAILURE_PROB else int(poller.poll())
|
||||||
|
for poller in pollers
|
||||||
|
]
|
||||||
|
else:
|
||||||
polls = [int(poller.poll()) for poller in pollers]
|
polls = [int(poller.poll()) for poller in pollers]
|
||||||
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
tensor_to_reduce = torch.tensor(polls, dtype=torch.uint8, device="cpu")
|
||||||
dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
|
dist.all_reduce(tensor_to_reduce, op=dist.ReduceOp.MIN, group=gloo_group)
|
||||||
|
|||||||
@@ -48,7 +48,9 @@ from sglang.global_config import global_config
|
|||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.disaggregation.base import BaseKVSender
|
from sglang.srt.disaggregation.base import BaseKVSender
|
||||||
from sglang.srt.disaggregation.decode import ScheduleBatchDisaggregationDecodeMixin
|
from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
|
||||||
|
ScheduleBatchDisaggregationDecodeMixin,
|
||||||
|
)
|
||||||
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
from sglang.srt.layers.multimodal import gpu_tensor_hash
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
|
|||||||
@@ -582,6 +582,8 @@ class Scheduler(
|
|||||||
gloo_group=self.attn_tp_cpu_group,
|
gloo_group=self.attn_tp_cpu_group,
|
||||||
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
|
||||||
metadata_buffers=metadata_buffers,
|
metadata_buffers=metadata_buffers,
|
||||||
|
scheduler=self,
|
||||||
|
tree_cache=self.tree_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The decode requests pending for pre-allocation
|
# The decode requests pending for pre-allocation
|
||||||
|
|||||||
@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
|
|||||||
|
|
||||||
def cache_finished_req(self, req: Req):
|
def cache_finished_req(self, req: Req):
|
||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : len(req.origin_input_ids) + len(req.output_ids) - 1
|
req.req_pool_idx,
|
||||||
|
# For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
|
||||||
|
: len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
|
||||||
]
|
]
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.token_to_kv_pool_allocator.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
|
|||||||
Reference in New Issue
Block a user