diff --git a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
index d32ace3d..b9aec1ae 100644
--- a/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
+++ b/examples/disaggregated_prefill_v1/load_balance_proxy_layerwise_server_example.py
@@ -86,9 +86,11 @@
 import argparse
 import asyncio
+import copy
 import functools
 import heapq
 import ipaddress
+import json
 import os
 import sys
 import uuid
@@ -412,7 +414,7 @@ async def _handle_completions(api: str, request: Request):
         request_length = len(req_body)
         request_id = await proxy_state.next_req_id()
         request_id_api = get_api_request_id(api, request_id)
-        proxy_state.req_data_dict[request_id_api] = (req_data, request_length, api)
+        proxy_state.req_data_dict[request_id_api] = (copy.deepcopy(req_data), request_length, api)
         req_data["kv_transfer_params"] = {
             "do_remote_decode": False,
             "do_remote_prefill": True,
@@ -428,19 +430,91 @@ async def _handle_completions(api: str, request: Request):

         # Stream response from decoder
         released_kv = False
+        # Record request info for recompute
+        stream_flag = bool(req_data.get("stream", False))
+        chat_flag = "messages" in req_data
+        if "prompt" in req_data:
+            origin_prompt = req_data["prompt"]
+        elif chat_flag:
+            messages = req_data["messages"]
+            origin_prompt = messages[0].get("content", "")
+            if isinstance(origin_prompt, list):
+                origin_prompt = origin_prompt[0].get("text", "")
+        else:
+            origin_prompt = ""
+        # refer to vLLM sampling_params: max_tokens default value
+        origin_max_tokens = req_data.get("max_tokens", 16)
+
         async def generate_stream():
             nonlocal released_kv
+            generated_token = ""
+            released_kv = False
+            retry_count = 0
+            retry = True
+            completion_tokens = 0
             # Only one await per chunk, minimal logic in loop
             try:
-                async for chunk in stream_service_response_with_retry(
-                    decoder.client,
-                    api,
-                    req_data,
-                    request_id=request_id,
-                    max_retries=global_args.max_retries,
-                    base_delay=global_args.retry_delay,
-                ):
-                    yield chunk
+                while retry:
+                    retry = False
+                    async for chunk in stream_service_response_with_retry(
+                        decoder.client,
+                        api,
+                        req_data,
+                        request_id=request_id,
+                        max_retries=global_args.max_retries,
+                        base_delay=global_args.retry_delay,
+                    ):
+                        try:
+                            chunk_str = chunk.decode("utf-8").strip()
+                        except UnicodeDecodeError:
+                            logger.debug(f"Skipping chunk: {chunk}")
+                            yield chunk
+                            continue
+                        if not chunk_str:
+                            continue
+                        if chunk_str.startswith("data: "):
+                            chunk_str = chunk_str[len("data: "):]
+                        try:
+                            chunk_json = json.loads(chunk_str)
+                        except json.JSONDecodeError:
+                            # Not JSON (e.g. the [DONE] sentinel); forward it as-is
+                            logger.debug(f"Skipping chunk: {chunk_str}")
+                            yield chunk
+                            continue
+                        choices = chunk_json.get("choices", [])
+                        if not choices:
+                            yield chunk
+                            continue
+
+                        choice = choices[0]
+                        delta = choice.get("delta") or {}
+                        message = choice.get("message") or {}
+                        content = delta.get("content") or message.get("content") or choice.get("text") or ""
+                        generated_token += content
+
+                        stop_reason = choice.get("stop_reason")
+                        usage = chunk_json.get("usage", {})
+                        completion_tokens = (
+                            (completion_tokens + 1)
+                            if stream_flag
+                            else (completion_tokens + usage.get("completion_tokens", 0))
+                        )
+                        if stop_reason == "recomputed":
+                            retry = True
+                            retry_count += 1
+                            if chat_flag:
+                                messages[0]["content"] = origin_prompt + generated_token
+                            else:
+                                req_data["prompt"] = origin_prompt + generated_token
+                            req_data["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
+                            break
+                        if retry_count > 0 and not stream_flag:
+                            if chat_flag:
+                                choice["message"]["content"] = generated_token
+                            else:
+                                choice["text"] = generated_token
+                            chunk = json.dumps(chunk_json).encode("utf-8")
+                        yield chunk
             except Exception as e:
                 logger.error(
                     f"Error during streaming from decoder {decoder.url}: {str(e)} "
@@ -451,7 +525,10 @@ async def _handle_completions(api: str, request: Request):
                 # After streaming done, release tokens
                 proxy_state.release_decoder(decoder_idx, decoder_score)

-        return StreamingResponse(generate_stream(), media_type="application/json")
+        if stream_flag:
+            return StreamingResponse(generate_stream(), media_type="text/event-stream")
+        else:
+            return StreamingResponse(generate_stream(), media_type="application/json")

     except Exception as e:
         import traceback
diff --git a/vllm_ascend/core/recompute_scheduler.py b/vllm_ascend/core/recompute_scheduler.py
index 7208d91c..90ab413a 100644
--- a/vllm_ascend/core/recompute_scheduler.py
+++ b/vllm_ascend/core/recompute_scheduler.py
@@ -642,7 +642,7 @@ class RecomputeScheduler(Scheduler):
                     EngineCoreOutput(
                         request_id=req_info.request_id,
                         finish_reason=FinishReason.STOP,
-                        new_token_ids=[req_info.output_token_ids[-1]],
+                        new_token_ids=[],
                        stop_reason="recomputed",
                     )
                 )
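Note (not part of the patch): the retry path above resubmits the request whenever the decoder reports stop_reason == "recomputed", extending the original prompt with the text generated so far and shrinking max_tokens by the tokens already counted. The standalone sketch below illustrates that bookkeeping in isolation; the helper names (parse_sse_chunk, build_retry_request) are hypothetical and are not part of vllm-ascend.

import json


def parse_sse_chunk(chunk: bytes):
    """Decode one SSE chunk; return its JSON payload, or None for empty/[DONE] chunks."""
    text = chunk.decode("utf-8").strip()
    if text.startswith("data: "):
        text = text[len("data: "):]
    if not text or text == "[DONE]":
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return None


def build_retry_request(req_data: dict, origin_prompt: str, origin_max_tokens: int,
                        generated_text: str, completion_tokens: int,
                        retry_count: int) -> dict:
    """Rebuild a /v1/completions body after a 'recomputed' stop, as the proxy does."""
    retried = dict(req_data)
    # Resume generation from the original prompt plus everything received so far.
    retried["prompt"] = origin_prompt + generated_text
    # Shrink the token budget by what was already counted; the patch adds
    # retry_count back, one token per retry.
    retried["max_tokens"] = origin_max_tokens - completion_tokens + retry_count
    return retried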