[P/D] Layerwise connector: support the recompute scheduler (#5900)

### What this PR does / why we need it?
Add recompute scheduler support to the layerwise connector.

NOTE:
Triggering recompute will invoke the tokenizer again, which may lead to
precision fluctuations.

[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.13.0
- vLLM main:
bde38c11df

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
liziyu
2026-02-07 15:24:42 +08:00
committed by GitHub
parent d266fd7b47
commit e5f0e0eaf7
2 changed files with 89 additions and 12 deletions

View File

@@ -86,9 +86,11 @@
import argparse import argparse
import asyncio import asyncio
import copy
import functools import functools
import heapq import heapq
import ipaddress import ipaddress
import json
import os import os
import sys import sys
import uuid import uuid
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
request_length = len(req_body) request_length = len(req_body)
request_id = await proxy_state.next_req_id() request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id) request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api) proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
req_data["kv_transfer_params"] = { req_data["kv_transfer_params"] = {
"do_remote_decode": False, "do_remote_decode": False,
"do_remote_prefill": True, "do_remote_prefill": True,
@@ -428,10 +430,32 @@ async def _handle_completions(api: str, request: Request):
# Stream response from decoder # Stream response from decoder
released_kv = False released_kv = False
# Record request info for recompute
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
if isinstance(origin_prompt, list):
origin_prompt = origin_prompt[0].get("text", "")
else:
origin_prompt = ""
# refer to vLLM sampling_params: max_token default value
origin_max_tokens = req_data.get("max_tokens", 16)
async def generate_stream(): async def generate_stream():
nonlocal released_kv nonlocal released_kv
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
# Only one await per chunk, minimal logic in loop # Only one await per chunk, minimal logic in loop
try: try:
while retry:
retry = False
async for chunk in stream_service_response_with_retry( async for chunk in stream_service_response_with_retry(
decoder.client, decoder.client,
api, api,
@@ -440,6 +464,56 @@ async def _handle_completions(api: str, request: Request):
max_retries=global_args.max_retries, max_retries=global_args.max_retries,
base_delay=global_args.retry_delay, base_delay=global_args.retry_delay,
): ):
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: ") :]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = delta.get("content") or message.get("content") or choice.get("text") or ""
generated_token += content
stop_reason = choice.get("stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (
(completion_tokens + 1)
if stream_flag
else (completion_tokens + usage.get("completion_tokens"))
)
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0]["content"] = origin_prompt + generated_token
else:
req_data["prompt"] = origin_prompt + generated_token
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"]["content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk yield chunk
except Exception as e: except Exception as e:
logger.error( logger.error(
@@ -451,6 +525,9 @@ async def _handle_completions(api: str, request: Request):
# After streaming done, release tokens # After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score) proxy_state.release_decoder(decoder_idx, decoder_score)
if stream_flag:
return StreamingResponse(generate_stream(), media_type="text/event-stream")
else:
return StreamingResponse(generate_stream(), media_type="application/json") return StreamingResponse(generate_stream(), media_type="application/json")
except Exception as e: except Exception as e:
import traceback import traceback

View File

@@ -642,7 +642,7 @@ class RecomputeScheduler(Scheduler):
EngineCoreOutput( EngineCoreOutput(
request_id=req_info.request_id, request_id=req_info.request_id,
finish_reason=FinishReason.STOP, finish_reason=FinishReason.STOP,
new_token_ids=[req_info.output_token_ids[-1]], new_token_ids=[],
stop_reason="recomputed", stop_reason="recomputed",
) )
) )