[P/D] layerwise connector support recompute scheduler (#5900)

### What this PR does / why we need it?
layerwise connector support recompute scheduler. 

NOTE:
Triggering recompute will invoke the tokenizer again, which may lead to
precision fluctuations.

[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
liziyu
2026-02-07 15:24:42 +08:00
committed by GitHub
parent d266fd7b47
commit e5f0e0eaf7
2 changed files with 89 additions and 12 deletions

View File

@@ -86,9 +86,11 @@
import argparse
import asyncio
import copy
import functools
import heapq
import ipaddress
import json
import os
import sys
import uuid
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
request_length = len(req_body)
request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
req_data["kv_transfer_params"] = {
"do_remote_decode": False,
"do_remote_prefill": True,
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):
# Stream response from decoder
released_kv = False
# Record request info for recompute
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
if isinstance(origin_prompt, list):
origin_prompt = origin_prompt[0].get("text", "")
else:
origin_prompt = ""
# refer to vLLM sampling_params: max_token default value
origin_max_tokens = req_data.get("max_tokens", 16)
async def generate_stream():
nonlocal released_kv
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
yield chunk
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay,
):
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: ") :]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = delta.get("content") or message.get("content") or choice.get("text") or ""
generated_token += content
stop_reason = choice.get("stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (
(completion_tokens + 1)
if stream_flag
else (completion_tokens + usage.get("completion_tokens"))
)
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0]["content"] = origin_prompt + generated_token
else:
req_data["prompt"] = origin_prompt + generated_token
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"]["content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} "
@@ -451,7 +525,10 @@ async def _handle_completions(api: str, request: Request):
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(), media_type="application/json")
if stream_flag:
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e:
import traceback

View File

@@ -642,7 +642,7 @@ class RecomputeScheduler(Scheduler):
EngineCoreOutput(
request_id=req_info.request_id,
finish_reason=FinishReason.STOP,
new_token_ids=[req_info.output_token_ids[-1]],
new_token_ids=[],
stop_reason="recomputed",
)
)