[bugfix] layerwise D first plan (#3866)

### What this PR does / why we need it?
Refactored the layerwise code to send each request to the D (decode) node
first, preventing P (prefill) node hangs caused by communication timeouts when DP > 1.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.

- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
wangxiaoteng888
2025-10-30 22:20:34 +08:00
committed by GitHub
parent 627f20ce26
commit 2c291bc63f
4 changed files with 963 additions and 1354 deletions

View File

@@ -88,18 +88,17 @@ import argparse
import asyncio import asyncio
import functools import functools
import heapq import heapq
import json
import os import os
import sys import sys
import threading import threading
import uuid import uuid
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass from typing import List
from typing import Any, List
import httpx import httpx
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -107,7 +106,6 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available # Add uvloop for faster event loop if available
try: try:
import uvloop import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError: except ImportError:
pass pass
@@ -154,6 +152,9 @@ class ProxyState:
heapq.heapify(self.prefiller_heap) heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap) heapq.heapify(self.decoder_heap)
self.req_id_future = {} self.req_id_future = {}
self.req_data_dict = {}
self.tokenizer = AutoTokenizer.from_pretrained(
global_args.tokenizer_dir)
def _update_prefiller_priority(self, server_idx: int): def _update_prefiller_priority(self, server_idx: int):
"""Update the priority of a prefiller server in the heap.""" """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
nargs="+", nargs="+",
default=["localhost"]) default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002]) parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--tokenizer-dir",
type=str,
default="/mnt/weight/Qwen3-235B-A22B-W8A8",
help="Maximum number of retries for HTTP requests")
parser.add_argument("--max-retries", parser.add_argument("--max-retries",
type=int, type=int,
default=3, default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
aborted_requests = proxy_state.aquire_aborted_prefiller_requests( aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
prefiller_id) prefiller_id)
req_data = req_data.copy() req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
req_data["stream"] = False req_data["stream"] = False
req_data["max_tokens"] = 1 req_data["max_tokens"] = 1
if "stream_options" in req_data: if "stream_options" in req_data:
@@ -458,59 +452,11 @@ def get_api_request_id(api, req_id):
return "chatcmpl-" + req_id return "chatcmpl-" + req_id
async def _handle_select_instance(api: str, req_data: Any, def get_origin_request_id(api, req_id):
request_length: int): if api == "/completions":
prefiller_score = proxy_state.calculate_prefill_scores(request_length) return req_id.replace("cmpl-", "").replace("-0", "")
logger.debug( elif api == "/chat/completions":
f"Request length: {request_length}, Prefiller score: {prefiller_score}" return req_id.replace("chatcmpl-", "")
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(
send_request_to_service(prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
@dataclass
class InstanceInfo:
request_id: str
prefiller_idx: int
prefiller_score: float
prefiller: ServerState
decoder_idx: int
decoder_score: float
decoder: ServerState
async def _handle_completions(api: str, request: Request): async def _handle_completions(api: str, request: Request):
@@ -518,120 +464,47 @@ async def _handle_completions(api: str, request: Request):
req_data = await request.json() req_data = await request.json()
req_body = await request.body() req_body = await request.body()
request_length = len(req_body) request_length = len(req_body)
instance_info = await _handle_select_instance(api, req_data, request_id = await proxy_state.next_req_id()
request_length) request_id_api = get_api_request_id(api, request_id)
stream_flag = bool(req_data.get("stream", False)) proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
chat_flag = "messages" in req_data api)
req_data['kv_transfer_params'] = {
if "prompt" in req_data: "do_remote_decode":
origin_prompt = req_data["prompt"] False,
elif chat_flag: "do_remote_prefill":
messages = req_data["messages"] True,
origin_prompt = messages[0].get("content", "") "metaserver":
else: f"http://{global_args.host}:{global_args.port}/v1/metaserver"
origin_prompt = "" }
# refer to vLLM sampling_params: max_token default value # Select decoder
origin_max_tokens = req_data.get("max_tokens", 16) decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
# logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
async def generate_stream(): async def generate_stream():
nonlocal instance_info nonlocal released_kv
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
# Only one await per chunk, minimal logic in loop # Only one await per chunk, minimal logic in loop
try: try:
while retry: async for chunk in stream_service_response_with_retry(
retry = False decoder.client,
async for chunk in stream_service_response_with_retry( api,
instance_info.decoder.client, req_data,
api, request_id=request_id,
req_data, max_retries=global_args.max_retries,
request_id=instance_info.request_id, base_delay=global_args.retry_delay):
max_retries=global_args.max_retries, yield chunk
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
released_kv = True
try:
chunk_str = chunk.decode("utf-8").strip()
except UnicodeDecodeError:
logger.debug(
f"Skipping chunk: {chunk}")
yield chunk
continue
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.debug(
f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
generated_token += content
stop_reason = choice.get(
"stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choices[0]["message"][
"content"] = generated_token
else:
choices[0]["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it" f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
) )
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
# After streaming done, release tokens # After streaming done, release tokens
proxy_state.release_decoder(instance_info.decoder_idx, proxy_state.release_decoder(decoder_idx, decoder_score)
instance_info.decoder_score)
return StreamingResponse(generate_stream(), return StreamingResponse(generate_stream(),
media_type="application/json") media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
@app.post("/v1/metaserver") @app.post("/v1/metaserver")
async def metaserver(request: Request): async def metaserver(request: Request):
try: try:
req_data = await request.json() kv_transfer_params = await request.json()
request_id = req_data.pop("request_id", None)
if request_id in proxy_state.req_id_future: request_id = kv_transfer_params["request_id"]
result_future = proxy_state.req_id_future[request_id] assert request_id in proxy_state.req_data_dict
result_future.set_result(req_data) req_data, request_length, api = proxy_state.req_data_dict[request_id]
request_id = get_origin_request_id(api, request_id)
req_data["kv_transfer_params"] = kv_transfer_params
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
except Exception as e: except Exception as e:
logger.error(f"Post metaserver failed with: {str(e)}") logger.error(f"Post metaserver failed with: {str(e)}")
@@ -682,5 +577,4 @@ if __name__ == '__main__':
global global_args global global_args
global_args = parse_args() global_args = parse_args()
import uvicorn import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port) uvicorn.run(app, host=global_args.host, port=global_args.port)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -96,7 +96,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
parallel_config.data_parallel_size, num_head_replica, -1, parallel_config.data_parallel_size, num_head_replica, -1,
alltoall_group_size alltoall_group_size
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.view(-1, alltoall_group_size).unbind(0) group_ranks = group_ranks.reshape(-1,
alltoall_group_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks] group_ranks = [x.tolist() for x in group_ranks]
local_rank = get_world_group().local_rank local_rank = get_world_group().local_rank
num = next( num = next(