[PD] Support prefill overlap + Ensure no race condition (#5609)
This commit is contained in:
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
@@ -204,6 +205,40 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# Otherwise, it hangs under high concurrency
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap_disagg_prefill(self):
|
||||
self.result_queue = deque()
|
||||
|
||||
while True:
|
||||
recv_reqs = self.recv_requests()
|
||||
self.process_input_requests(recv_reqs)
|
||||
self.waiting_queue.extend(
|
||||
self.disagg_prefill_pending_queue.pop_bootstrapped()
|
||||
)
|
||||
self.process_prefill_chunk()
|
||||
batch = self.get_new_batch_prefill()
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch:
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)
|
||||
|
||||
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||
self.process_disagg_prefill_inflight_queue()
|
||||
|
||||
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
|
||||
self.check_memory()
|
||||
self.new_token_ratio = self.init_new_token_ratio
|
||||
|
||||
self.last_batch = batch
|
||||
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
|
||||
# Otherwise, it hangs under high concurrency
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def process_batch_result_disagg_prefill(
|
||||
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
|
||||
) -> None:
|
||||
@@ -212,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
Adapted from process_batch_result_prefill
|
||||
"""
|
||||
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
(
|
||||
logits_output,
|
||||
next_token_ids,
|
||||
extend_input_len_per_req,
|
||||
extend_logprob_start_len_per_req,
|
||||
bid,
|
||||
) = (
|
||||
result.logits_output,
|
||||
result.next_token_ids,
|
||||
result.extend_input_len_per_req,
|
||||
result.extend_logprob_start_len_per_req,
|
||||
result.bid,
|
||||
)
|
||||
|
||||
# Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
|
||||
if self.enable_overlap:
|
||||
# wait
|
||||
_, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||
else:
|
||||
next_token_ids = result.next_token_ids.tolist()
|
||||
|
||||
for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
|
||||
req: Req
|
||||
@@ -226,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# being chunked reqs' prefill is not finished
|
||||
req.is_chunked -= 1
|
||||
|
||||
# TODO: Not sure if this is necessary
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
# We need to remove this for overlap schedule.
|
||||
self.current_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
if self.enable_overlap:
|
||||
self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
|
||||
|
||||
def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
|
||||
"""
|
||||
@@ -276,20 +326,37 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
# only finished requests to running_batch.
|
||||
self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
|
||||
self.tree_cache.cache_unfinished_req(self.chunked_req)
|
||||
self.send_kv_chunk(self.chunked_req)
|
||||
if (
|
||||
self.enable_overlap
|
||||
): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
|
||||
self.chunked_req.tmp_end_idx = min(
|
||||
len(self.chunked_req.fill_ids),
|
||||
len(self.chunked_req.origin_input_ids),
|
||||
)
|
||||
else:
|
||||
self.send_kv_chunk(self.chunked_req)
|
||||
# chunked request keeps its rid but will get a new req_pool_idx
|
||||
self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
|
||||
self.running_batch.batch_is_full = False
|
||||
|
||||
def send_kv_chunk(
|
||||
self: Scheduler, req: Req, token_id: Optional[int] = None
|
||||
self: Scheduler,
|
||||
req: Req,
|
||||
token_id: Optional[int] = None,
|
||||
end_idx: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Send a prefilled chunk to the decode server
|
||||
"""
|
||||
page_size = self.token_to_kv_pool_allocator.page_size
|
||||
start_idx = req.start_send_idx
|
||||
end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
# if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
|
||||
# the resolved length is not the same as fill_ids's length
|
||||
end_idx = (
|
||||
end_idx
|
||||
if end_idx is not None
|
||||
else min(len(req.fill_ids), len(req.origin_input_ids))
|
||||
)
|
||||
last_chunk = token_id is not None
|
||||
|
||||
if (not last_chunk) and (
|
||||
@@ -302,7 +369,7 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
req.start_send_idx = end_idx
|
||||
|
||||
kv_indices = (
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx][start_idx:end_idx]
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
@@ -539,6 +539,11 @@ class Req:
|
||||
# The first output_id transferred from prefill instance.
|
||||
self.transferred_output_id: Optional[int] = None
|
||||
|
||||
# For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
|
||||
# This is because kv is not ready in `process_prefill_chunk`.
|
||||
# We use `tmp_end_idx` to store the end index of the kv cache to send.
|
||||
self.tmp_end_idx: int = -1
|
||||
|
||||
@property
|
||||
def seqlen(self):
|
||||
return len(self.origin_input_ids) + len(self.output_ids)
|
||||
|
||||
@@ -2014,7 +2014,10 @@ def run_scheduler_process(
|
||||
else:
|
||||
scheduler.event_loop_normal()
|
||||
elif disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_prefill()
|
||||
else:
|
||||
scheduler.event_loop_normal_disagg_prefill()
|
||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||
if scheduler.enable_overlap:
|
||||
scheduler.event_loop_overlap_disagg_decode()
|
||||
|
||||
@@ -388,8 +388,6 @@ class ServerArgs:
|
||||
if self.disaggregation_mode == "prefill":
|
||||
self.disable_cuda_graph = True
|
||||
logger.warning("Cuda graph is disabled for prefill server")
|
||||
self.disable_overlap_schedule = True
|
||||
logger.warning("Overlap scheduler is disabled for prefill server")
|
||||
elif self.disaggregation_mode == "decode":
|
||||
self.disable_radix_cache = True
|
||||
logger.warning("KV cache is forced as chunk cache for decode server")
|
||||
|
||||
@@ -1,13 +1,29 @@
|
||||
prompt = [0] * 431
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
prompt = """
|
||||
According to CNBC's Faber, the investors present on the call interpreted this statement as an indication of an upcoming funding round. While speculative, Faber believes the funding round could be as large as $25 billion, and bestow a valuation of between $150 billion and $200 billion on xAI.
|
||||
|
||||
For the benefit of those who might not be aware, xAI recently acquired the social media platform X in an all-stock deal that valued the former at $80 billion and the latter at $33 billion, inclusive of $12 billion in liabilities. This meant that the deal bestowed a gross valuation of $45 billion on X before factoring in its debt load of $12 billion.
|
||||
|
||||
Bear in mind that Elon Musk took X (then called Twitter) private back in 2022 in a $44 billion deal. Since then, Musk has managed to stem X's cash bleed, with the company reportedly generating $1.2 billion in adjusted EBITDA in 2024.
|
||||
|
||||
According to the investors present on the call, xAI is currently generating around $1 billion in annual revenue. This contrasts sharply with the erstwhile muted expectations of many investors, who did not expect the startup to generate any material revenue this year.
|
||||
|
||||
Elsewhere, Faber also alludes to the fact that xAI is already working on its next big training supercluster, officially dubbed the Colossus 2, which is expected to eventually house as many as 1 million NVIDIA GPUs at a cost of between $35 billion and $40 billion.
|
||||
|
||||
|
||||
Even though xAI's Grok LLM is already largely comparable with OpenAI's cutting-edge models, the Colossus 2 would significantly up the ante, and could feasibly challenge OpenAI's apex position in the AI sphere.
|
||||
|
||||
Give your honest take on the above text:
|
||||
"""
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={"input_ids": [prompt] * 32, "sampling_params": {"temperature": 0}},
|
||||
json={"text": prompt, "sampling_params": {"temperature": 0}},
|
||||
)
|
||||
|
||||
|
||||
# print("Response content (raw):", response.content)
|
||||
response_json = response.json()
|
||||
print(response_json["text"])
|
||||
|
||||
Reference in New Issue
Block a user