[bugfix_v0.11.0-dev] layerwise D first plan (#3907)

### What this PR does / why we need it?
Refactored the layerwise code to send to the D node first, preventing
P-node hangs due to communication timeouts when DP > 1.
---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
wangxiaoteng888
2025-10-30 22:21:11 +08:00
committed by GitHub
parent d5a9aba03f
commit af7a56550b
5 changed files with 965 additions and 1356 deletions

View File

@@ -88,18 +88,17 @@ import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import threading
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import List
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -107,7 +106,6 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
@@ -154,6 +152,9 @@ class ProxyState:
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
self.req_id_future = {}
self.req_data_dict = {}
self.tokenizer = AutoTokenizer.from_pretrained(
global_args.tokenizer_dir)
def _update_prefiller_priority(self, server_idx: int):
"""Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--tokenizer-dir",
type=str,
default="/mnt/weight/Qwen3-235B-A22B-W8A8",
help="Path to the tokenizer directory used to measure request length")
parser.add_argument("--max-retries",
type=int,
default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
@@ -458,59 +452,11 @@ def get_api_request_id(api, req_id):
return "chatcmpl-" + req_id
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(
send_request_to_service(prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
@dataclass
class InstanceInfo:
request_id: str
prefiller_idx: int
prefiller_score: float
prefiller: ServerState
decoder_idx: int
decoder_score: float
decoder: ServerState
def get_origin_request_id(api, req_id):
if api == "/completions":
return req_id.replace("cmpl-", "").replace("-0", "")
elif api == "/chat/completions":
return req_id.replace("chatcmpl-", "")
async def _handle_completions(api: str, request: Request):
@@ -518,120 +464,47 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data,
request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
else:
origin_prompt = ""
# refer to vLLM sampling_params: max_token default value
origin_max_tokens = req_data.get("max_tokens", 16)
request_id = await proxy_state.next_req_id()
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
api)
req_data['kv_transfer_params'] = {
"do_remote_decode":
False,
"do_remote_prefill":
True,
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
# logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
async def generate_stream():
nonlocal instance_info
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
nonlocal released_kv
# Only one await per chunk, minimal logic in loop
try:
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(
f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
generated_token += content
stop_reason = choice.get(
"stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choices[0]["message"][
"content"] = generated_token
else:
choices[0]["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)}; the aborted request {instance_info.request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
f"Error during streaming from decoder {decoder.url}: {str(e)}; the aborted request {request_id} will be routed to the target prefiller when a new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
@app.post("/v1/metaserver")
async def metaserver(request: Request):
try:
req_data = await request.json()
request_id = req_data.pop("request_id", None)
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(req_data)
kv_transfer_params = await request.json()
request_id = kv_transfer_params["request_id"]
assert request_id in proxy_state.req_data_dict
req_data, request_length, api = proxy_state.req_data_dict[request_id]
request_id = get_origin_request_id(api, request_id)
req_data["kv_transfer_params"] = kv_transfer_params
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}")
@@ -682,5 +577,4 @@ if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)

View File

@@ -1136,4 +1136,4 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
if __name__ == '__main__':
unittest.main()
unittest.main()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -89,7 +89,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
parallel_config.data_parallel_size, num_head_replica, -1,
alltoall_group_size
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.view(-1, alltoall_group_size).unbind(0)
group_ranks = group_ranks.reshape(-1,
alltoall_group_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
local_rank = get_world_group().local_rank
num = next(