[P/D] Layerwise connector: support recompute scheduler (#5900)
### What this PR does / why we need it?
Add recompute-scheduler support to the layerwise connector.
NOTE: triggering a recompute invokes the tokenizer again, which may lead to
precision fluctuations.
[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842
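To illustrate the note above: on a recompute, the proxy re-sends the request as plain text (the original prompt concatenated with the text generated so far), so the serving side tokenizes that string from scratch instead of continuing from the token ids it already produced. A minimal sketch of why that can shift results, using the Hugging Face `transformers` tokenizer purely as an illustrative stand-in (not part of this PR):

```python
# Sketch only: re-tokenizing "prompt + generated text" is not guaranteed to
# reproduce the exact token-id history the model had before the recompute,
# which is where the precision fluctuations mentioned above can come from.
from transformers import AutoTokenizer  # illustrative tokenizer, not from this PR

tok = AutoTokenizer.from_pretrained("gpt2")

origin_prompt = "Layerwise connector"
generated_ids = tok.encode(" supports recompute")  # pretend the model produced these
generated_text = tok.decode(generated_ids)

# What the proxy re-sends after a recompute: one concatenated string.
retokenized = tok.encode(origin_prompt + generated_text)
# What the model actually had: prompt ids followed by the generated ids.
original_history = tok.encode(origin_prompt) + generated_ids

# Often equal, but merges across the prompt/generation boundary can differ,
# so equality is not guaranteed and the retried forward pass may start from a
# slightly different token history.
print(retokenized == original_history)
```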
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: bde38c11df
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
@@ -86,9 +86,11 @@
 import argparse
 import asyncio
+import copy
 import functools
 import heapq
 import ipaddress
 import json
 import os
 import sys
 import uuid
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
         request_length = len(req_body)
         request_id = await proxy_state.next_req_id()
         request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
+        proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
         req_data["kv_transfer_params"] = {
             "do_remote_decode": False,
             "do_remote_prefill": True,
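The switch to `copy.deepcopy(req_data)` above matters because the very next lines mutate `req_data` in place (injecting `kv_transfer_params`), and the recompute path later rewrites `prompt` and `max_tokens` as well; without a deep copy, the snapshot kept in `req_data_dict` would alias those mutations. A small self-contained illustration of the aliasing (dictionary contents are made up for the example):

```python
import copy

req_data = {"prompt": "hello", "max_tokens": 16}

# Storing the dict directly keeps a reference, not a snapshot.
snapshot_by_ref = req_data
snapshot_by_copy = copy.deepcopy(req_data)

# Later in-place mutations, like the kv_transfer_params injection above,
# leak into the by-reference "snapshot" but not into the deep copy.
req_data["kv_transfer_params"] = {"do_remote_prefill": True}
req_data["max_tokens"] = 8

print("kv_transfer_params" in snapshot_by_ref)   # True: aliased
print("kv_transfer_params" in snapshot_by_copy)  # False: isolated snapshot
```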
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):
         # Stream response from decoder
         released_kv = False

+        # Record request info for recompute
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+            if isinstance(origin_prompt, list):
+                origin_prompt = origin_prompt[0].get("text", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_tokens default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
             nonlocal released_kv
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                    decoder.client,
-                    api,
-                    req_data,
-                    request_id=request_id,
-                    max_retries=global_args.max_retries,
-                    base_delay=global_args.retry_delay,
-                ):
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay,
+                    ):
+                        try:
+                            chunk_str = chunk.decode("utf-8").strip()
+                        except UnicodeDecodeError:
+                            logger.debug(f"Skipping chunk: {chunk}")
+                            yield chunk
+                            continue
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: ") :]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # if chunk is [done], skip it.
+                            logger.debug(f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = delta.get("content") or message.get("content") or choice.get("text") or ""
+                        generated_token += content
+
+                        stop_reason = choice.get("stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (
+                            (completion_tokens + 1)
+                            if stream_flag
+                            else (completion_tokens + usage.get("completion_tokens"))
+                        )
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0]["content"] = origin_prompt + generated_token
+                            else:
+                                req_data["prompt"] = origin_prompt + generated_token
+                            req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choice["message"]["content"] = generated_token
+                            else:
+                                choice["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
                     f"Error during streaming from decoder {decoder.url}: {str(e)} "
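One detail in the loop above worth spelling out is the retry bookkeeping: `generated_token` accumulates the text already streamed, and when a chunk arrives with `stop_reason == "recomputed"` the request is re-issued with `prompt = origin_prompt + generated_token` and `max_tokens = origin_max_tokens - completion_tokens + retry_count` (the `+ retry_count` appears to compensate for the signalling chunk counted once per recompute). A stand-alone sketch of just that arithmetic, with names mirroring the diff but the helper itself not taken from it:

```python
# Hedged sketch of the recompute retry bookkeeping shown above.
def rebuild_completion_request(req_data: dict, origin_prompt: str, generated_token: str,
                               origin_max_tokens: int, completion_tokens: int,
                               retry_count: int) -> dict:
    """Rebuild a plain-completions request after a 'recomputed' signal."""
    req_data["prompt"] = origin_prompt + generated_token
    req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
    return req_data

# Worked example: the client asked for 16 tokens, 6 chunks were streamed before
# the recompute signal (completion_tokens == 6), and this is the first retry.
req = rebuild_completion_request({}, "Hello", ", wor", 16, 6, 1)
print(req)  # {'prompt': 'Hello, wor', 'max_tokens': 11}
```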
@@ -451,7 +525,10 @@ async def _handle_completions(api: str, request: Request):
             # After streaming done, release tokens
             proxy_state.release_decoder(decoder_idx, decoder_score)

-        return StreamingResponse(generate_stream(), media_type="application/json")
+        if stream_flag:
+            return StreamingResponse(generate_stream(), media_type="text/event-stream")
+        else:
+            return StreamingResponse(generate_stream(), media_type="application/json")
     except Exception as e:
         import traceback

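The change above only affects the advertised content type: when the client asked for `stream: true`, the proxy now labels the response as server-sent events (`text/event-stream`) instead of `application/json`, which is what SSE-aware clients expect. A minimal FastAPI sketch of the same branching (the endpoint and payload are illustrative, not from this PR):

```python
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/v1/completions")
async def completions(request: Request):
    req_data = await request.json()
    stream_flag = bool(req_data.get("stream", False))

    async def generate_stream():
        # Placeholder generator; the real proxy forwards decoder chunks here.
        if stream_flag:
            yield b'data: {"choices": []}\n\n'
        else:
            yield b'{"choices": []}'

    media_type = "text/event-stream" if stream_flag else "application/json"
    return StreamingResponse(generate_stream(), media_type=media_type)
```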
@@ -642,7 +642,7 @@ class RecomputeScheduler(Scheduler):
                 EngineCoreOutput(
                     request_id=req_info.request_id,
                     finish_reason=FinishReason.STOP,
-                    new_token_ids=[req_info.output_token_ids[-1]],
+                    new_token_ids=[],
                     stop_reason="recomputed",
                 )
             )
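On the scheduler side, the hunk above changes how a recompute is signalled: the request is finished with `FinishReason.STOP` and `stop_reason="recomputed"`, but `new_token_ids` is now empty, so the last generated token is no longer replayed in the signalling output. The proxy loop earlier in this PR keys off that `stop_reason`; a rough sketch of the chunk shape that check is written against (field values are illustrative, not captured output):

```python
import json

# Illustrative chunk as the proxy-side loop might see it after the scheduler
# finishes a request with stop_reason="recomputed"; the exact shape is an
# assumption, not captured from a real response.
chunk = json.dumps({
    "choices": [
        {
            "index": 0,
            "delta": {"content": ""},  # empty: new_token_ids carries no replayed token
            "stop_reason": "recomputed",
            "finish_reason": "stop",
        }
    ]
}).encode("utf-8")

chunk_json = json.loads(chunk.decode("utf-8"))
choice = chunk_json["choices"][0]
if choice.get("stop_reason") == "recomputed":
    print("proxy would rebuild the prompt and retry the request")
```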