[P/D] layerwise connector support recompute scheduler (#5900)
### What this PR does / why we need it?
layerwise connector support recompute scheduler.
NOTE:
Triggering recompute will invoke the tokenizer again, which may lead to
precision fluctuations.
[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
bde38c11df
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -86,9 +86,11 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import heapq
|
import heapq
|
||||||
import ipaddress
|
import ipaddress
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import uuid
|
import uuid
|
||||||
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
request_length = len(req_body)
|
request_length = len(req_body)
|
||||||
request_id = await proxy_state.next_req_id()
|
request_id = await proxy_state.next_req_id()
|
||||||
request_id_api = get_api_request_id(api, request_id)
|
request_id_api = get_api_request_id(api, request_id)
|
||||||
proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
|
proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
|
||||||
req_data["kv_transfer_params"] = {
|
req_data["kv_transfer_params"] = {
|
||||||
"do_remote_decode": False,
|
"do_remote_decode": False,
|
||||||
"do_remote_prefill": True,
|
"do_remote_prefill": True,
|
||||||
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
# Stream response from decoder
|
# Stream response from decoder
|
||||||
released_kv = False
|
released_kv = False
|
||||||
|
|
||||||
|
# Record request info for recompute
|
||||||
|
stream_flag = bool(req_data.get("stream", False))
|
||||||
|
chat_flag = "messages" in req_data
|
||||||
|
if "prompt" in req_data:
|
||||||
|
origin_prompt = req_data["prompt"]
|
||||||
|
elif chat_flag:
|
||||||
|
messages = req_data["messages"]
|
||||||
|
origin_prompt = messages[0].get("content", "")
|
||||||
|
if isinstance(origin_prompt, list):
|
||||||
|
origin_prompt = origin_prompt[0].get("text", "")
|
||||||
|
else:
|
||||||
|
origin_prompt = ""
|
||||||
|
# refer to vLLM sampling_params: max_token default value
|
||||||
|
origin_max_tokens = req_data.get("max_tokens", 16)
|
||||||
|
|
||||||
async def generate_stream():
|
async def generate_stream():
|
||||||
nonlocal released_kv
|
nonlocal released_kv
|
||||||
|
generated_token = ""
|
||||||
|
released_kv = False
|
||||||
|
retry_count = 0
|
||||||
|
retry = True
|
||||||
|
completion_tokens = 0
|
||||||
# Only one await per chunk, minimal logic in loop
|
# Only one await per chunk, minimal logic in loop
|
||||||
try:
|
try:
|
||||||
async for chunk in stream_service_response_with_retry(
|
while retry:
|
||||||
decoder.client,
|
retry = False
|
||||||
api,
|
async for chunk in stream_service_response_with_retry(
|
||||||
req_data,
|
decoder.client,
|
||||||
request_id=request_id,
|
api,
|
||||||
max_retries=global_args.max_retries,
|
req_data,
|
||||||
base_delay=global_args.retry_delay,
|
request_id=request_id,
|
||||||
):
|
max_retries=global_args.max_retries,
|
||||||
yield chunk
|
base_delay=global_args.retry_delay,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
chunk_str = chunk.decode("utf-8").strip()
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
logger.debug(f"Skipping chunk: {chunk}")
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
if not chunk_str:
|
||||||
|
continue
|
||||||
|
if chunk_str.startswith("data: "):
|
||||||
|
chunk_str = chunk_str[len("data: ") :]
|
||||||
|
try:
|
||||||
|
chunk_json = json.loads(chunk_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# if chunk is [done], skip it.
|
||||||
|
logger.debug(f"Skipping chunk: {chunk_str}")
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
choices = chunk_json.get("choices", [])
|
||||||
|
if not choices:
|
||||||
|
yield chunk
|
||||||
|
continue
|
||||||
|
|
||||||
|
choice = choices[0]
|
||||||
|
delta = choice.get("delta") or {}
|
||||||
|
message = choice.get("message") or {}
|
||||||
|
content = delta.get("content") or message.get("content") or choice.get("text") or ""
|
||||||
|
generated_token += content
|
||||||
|
|
||||||
|
stop_reason = choice.get("stop_reason")
|
||||||
|
usage = chunk_json.get("usage", {})
|
||||||
|
completion_tokens = (
|
||||||
|
(completion_tokens + 1)
|
||||||
|
if stream_flag
|
||||||
|
else (completion_tokens + usage.get("completion_tokens"))
|
||||||
|
)
|
||||||
|
if stop_reason == "recomputed":
|
||||||
|
retry = True
|
||||||
|
retry_count += 1
|
||||||
|
if chat_flag:
|
||||||
|
messages[0]["content"] = origin_prompt + generated_token
|
||||||
|
else:
|
||||||
|
req_data["prompt"] = origin_prompt + generated_token
|
||||||
|
req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
|
||||||
|
break
|
||||||
|
if retry_count > 0 and not stream_flag:
|
||||||
|
if chat_flag:
|
||||||
|
choice["message"]["content"] = generated_token
|
||||||
|
else:
|
||||||
|
choice["text"] = generated_token
|
||||||
|
chunk = json.dumps(chunk_json).encode("utf-8")
|
||||||
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Error during streaming from decoder {decoder.url}: {str(e)} "
|
f"Error during streaming from decoder {decoder.url}: {str(e)} "
|
||||||
@@ -451,7 +525,10 @@ async def _handle_completions(api: str, request: Request):
|
|||||||
# After streaming done, release tokens
|
# After streaming done, release tokens
|
||||||
proxy_state.release_decoder(decoder_idx, decoder_score)
|
proxy_state.release_decoder(decoder_idx, decoder_score)
|
||||||
|
|
||||||
return StreamingResponse(generate_stream(), media_type="application/json")
|
if stream_flag:
|
||||||
|
return StreamingResponse(generate_stream(), media_type="text/event-stream")
|
||||||
|
else:
|
||||||
|
return StreamingResponse(generate_stream(), media_type="application/json")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
|||||||
@@ -642,7 +642,7 @@ class RecomputeScheduler(Scheduler):
|
|||||||
EngineCoreOutput(
|
EngineCoreOutput(
|
||||||
request_id=req_info.request_id,
|
request_id=req_info.request_id,
|
||||||
finish_reason=FinishReason.STOP,
|
finish_reason=FinishReason.STOP,
|
||||||
new_token_ids=[req_info.output_token_ids[-1]],
|
new_token_ids=[],
|
||||||
stop_reason="recomputed",
|
stop_reason="recomputed",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user