[bugfix] layerwise D first plan (#3866)
### What this PR does / why we need it?
Refactored the layerwise code so that requests are dispatched to the decode (D) node first, preventing prefill (P) node hangs caused by communication timeouts when DP > 1. A sketch of the new flow is shown below.
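A minimal sketch of the decode-first dispatch order, with hypothetical stand-ins (`select_decoder`, `select_prefiller`, `forward`, and the URLs/ports) for the proxy's real selection and HTTP forwarding logic; it is not the PR's actual implementation:

```python
import uuid

from fastapi import FastAPI, Request

app = FastAPI()
req_data_dict: dict = {}  # request_id -> original request payload


def select_decoder() -> str:
    return "http://decoder:8002"  # placeholder endpoint


def select_prefiller() -> str:
    return "http://prefiller:8001"  # placeholder endpoint


async def forward(url: str, req: dict, request_id: str) -> dict:
    # Stand-in for the proxy's httpx-based send/stream helpers.
    return {"url": url, "request_id": request_id}


@app.post("/v1/completions")
async def completions(request: Request):
    req = await request.json()
    request_id = str(uuid.uuid4())
    req_data_dict[request_id] = req
    # Step 1: contact the decoder (D) first; it reports back through
    # /v1/metaserver once it is ready for the KV transfer.
    req["kv_transfer_params"] = {
        "do_remote_prefill": True,
        "metaserver": "http://proxy:8000/v1/metaserver",
    }
    return await forward(select_decoder(), req, request_id)


@app.post("/v1/metaserver")
async def metaserver(request: Request):
    kv_transfer_params = await request.json()
    request_id = kv_transfer_params["request_id"]
    req = req_data_dict.pop(request_id)
    req["kv_transfer_params"] = kv_transfer_params
    # Step 2: only now is the prefiller (P) engaged, so it no longer
    # blocks on a decoder that has not been scheduled yet.
    return await forward(select_prefiller(), req, request_id)
```

With this ordering the P node only receives work after the decoder has registered the transfer, so it no longer idles past its communication timeout waiting for a decoder under DP > 1.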
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.
- vLLM version: v0.11.0
- vLLM main: 83f478bb19
---------
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
@@ -88,18 +88,17 @@ import argparse
 import asyncio
 import functools
 import heapq
-import json
 import os
 import sys
 import threading
 import uuid
 from contextlib import asynccontextmanager
-from dataclasses import dataclass
-from typing import Any, List
+from typing import List
 
 import httpx
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
+from transformers import AutoTokenizer
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
@@ -107,7 +106,6 @@ logger = init_logger(__name__)
 # Add uvloop for faster event loop if available
 try:
     import uvloop
-
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 except ImportError:
     pass
@@ -154,6 +152,9 @@ class ProxyState:
         heapq.heapify(self.prefiller_heap)
         heapq.heapify(self.decoder_heap)
         self.req_id_future = {}
+        self.req_data_dict = {}
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            global_args.tokenizer_dir)
 
     def _update_prefiller_priority(self, server_idx: int):
         """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                         nargs="+",
                         default=["localhost"])
     parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
+    parser.add_argument("--tokenizer-dir",
+                        type=str,
+                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
+                        help="Maximum number of retries for HTTP requests")
     parser.add_argument("--max-retries",
                         type=int,
                         default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
         aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
             prefiller_id)
         req_data = req_data.copy()
-        req_data['kv_transfer_params'] = {
-            "do_remote_decode": True,
-            "do_remote_prefill": False,
-            "remote_engine_id": None,
-            "remote_block_ids": None,
-            "remote_host": None,
-            "remote_port": None,
-            "aborted_request": list(aborted_requests),
-            "metaserver":
-            f"http://{global_args.host}:{global_args.port}/v1/metaserver"
-        }
         req_data["stream"] = False
         req_data["max_tokens"] = 1
         if "stream_options" in req_data:
@@ -458,59 +452,11 @@ def get_api_request_id(api, req_id):
     return "chatcmpl-" + req_id
 
 
-async def _handle_select_instance(api: str, req_data: Any,
-                                  request_length: int):
-    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
-    logger.debug(
-        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
-    )
-    request_id = await proxy_state.next_req_id()
-    # Select prefiller
-    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
-    prefiller = proxy_state.prefillers[prefiller_idx]
-    result_future = asyncio.Future()  # type: ignore
-    request_id_api = get_api_request_id(api, request_id)
-    proxy_state.req_id_future[request_id_api] = result_future
-    # Send request to prefiller
-    asyncio.get_running_loop().create_task(
-        send_request_to_service(prefiller.client,
-                                prefiller_idx,
-                                api,
-                                req_data,
-                                request_id,
-                                max_retries=global_args.max_retries,
-                                base_delay=global_args.retry_delay))
-    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
-
-    response = await result_future
-    del proxy_state.req_id_future[request_id_api]
-    req_data["kv_transfer_params"] = response
-
-    # Select decoder
-    decoder_score = proxy_state.calculate_decode_scores(request_length)
-    logger.debug("Decoder score: %f", decoder_score)
-    # Use the prefiller's kv_transfer_params to select decoder
-    decoder_idx = proxy_state.select_decoder(decoder_score)
-    decoder = proxy_state.decoders[decoder_idx]
-    logger.debug("Using %s %s", prefiller.url, decoder.url)
-    return InstanceInfo(request_id=request_id,
-                        prefiller_idx=prefiller_idx,
-                        prefiller_score=prefiller_score,
-                        prefiller=prefiller,
-                        decoder=decoder,
-                        decoder_idx=decoder_idx,
-                        decoder_score=decoder_score)
-
-
-@dataclass
-class InstanceInfo:
-    request_id: str
-    prefiller_idx: int
-    prefiller_score: float
-    prefiller: ServerState
-    decoder_idx: int
-    decoder_score: float
-    decoder: ServerState
+def get_origin_request_id(api, req_id):
+    if api == "/completions":
+        return req_id.replace("cmpl-", "").replace("-0", "")
+    elif api == "/chat/completions":
+        return req_id.replace("chatcmpl-", "")
 
 
 async def _handle_completions(api: str, request: Request):
@@ -518,120 +464,47 @@ async def _handle_completions(api: str, request: Request):
     req_data = await request.json()
     req_body = await request.body()
     request_length = len(req_body)
-    instance_info = await _handle_select_instance(api, req_data,
-                                                  request_length)
-    stream_flag = bool(req_data.get("stream", False))
-    chat_flag = "messages" in req_data
-
-    if "prompt" in req_data:
-        origin_prompt = req_data["prompt"]
-    elif chat_flag:
-        messages = req_data["messages"]
-        origin_prompt = messages[0].get("content", "")
-    else:
-        origin_prompt = ""
-    # refer to vLLM sampling_params: max_token default value
-    origin_max_tokens = req_data.get("max_tokens", 16)
+    request_id = await proxy_state.next_req_id()
+    request_id_api = get_api_request_id(api, request_id)
+    proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
+                                                 api)
+    req_data['kv_transfer_params'] = {
+        "do_remote_decode":
+        False,
+        "do_remote_prefill":
+        True,
+        "metaserver":
+        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
+    }
+    # Select decoder
+    decoder_score = proxy_state.calculate_decode_scores(request_length)
+    logger.debug("Decoder score: %f", decoder_score)
+    # Use the prefiller's kv_transfer_params to select decoder
+    decoder_idx = proxy_state.select_decoder(decoder_score)
+    decoder = proxy_state.decoders[decoder_idx]
+    # logger.debug("Using %s %s", prefiller.url, decoder.url)
+    # Stream response from decoder
+    released_kv = False
 
     async def generate_stream():
-        nonlocal instance_info
-        generated_token = ""
-        released_kv = False
-        retry_count = 0
-        retry = True
-        completion_tokens = 0
+        nonlocal released_kv
         # Only one await per chunk, minimal logic in loop
        try:
-            while retry:
-                retry = False
-                async for chunk in stream_service_response_with_retry(
-                        instance_info.decoder.client,
-                        api,
-                        req_data,
-                        request_id=instance_info.request_id,
-                        max_retries=global_args.max_retries,
-                        base_delay=global_args.retry_delay):
-                    if not released_kv and chunk:
-                        proxy_state.release_prefiller_kv(
-                            instance_info.prefiller_idx,
-                            instance_info.prefiller_score)
-                        released_kv = True
-                    try:
-                        chunk_str = chunk.decode("utf-8").strip()
-                    except UnicodeDecodeError:
-                        logger.debug(
-                            f"Skipping chunk: {chunk}")
-                        yield chunk
-                        continue
-                    if not chunk_str:
-                        continue
-                    if chunk_str.startswith("data: "):
-                        chunk_str = chunk_str[len("data: "):]
-                    try:
-                        chunk_json = json.loads(chunk_str)
-                    except json.JSONDecodeError:
-                        # if chunk is [done], skip it.
-                        logger.debug(
-                            f"Skipping chunk: {chunk_str}")
-                        yield chunk
-                        continue
-                    choices = chunk_json.get("choices", [])
-                    if not choices:
-                        yield chunk
-                        continue
-
-                    choice = choices[0]
-                    delta = choice.get("delta") or {}
-                    message = choice.get("message") or {}
-                    content = (
-                        delta.get("content")
-                        or message.get("content")
-                        or choice.get("text")
-                        or ""
-                    )
-                    generated_token += content
-
-                    stop_reason = choice.get(
-                        "stop_reason")
-                    usage = chunk_json.get("usage", {})
-                    completion_tokens = (completion_tokens + 1) if stream_flag else \
-                        (completion_tokens + usage.get("completion_tokens"))
-                    if stop_reason == "recomputed":
-                        retry = True
-                        retry_count += 1
-                        if chat_flag:
-                            messages[0][
-                                "content"] = origin_prompt + generated_token
-                        else:
-                            req_data[
-                                "prompt"] = origin_prompt + generated_token
-                        req_data[
-                            "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
-                        tmp_request_length = len(
-                            json.dumps(req_data).encode("utf-8"))
-                        instance_info = await _handle_select_instance(
-                            api, req_data, tmp_request_length)
-                        break
-                    if retry_count > 0 and not stream_flag:
-                        if chat_flag:
-                            choices[0]["message"][
-                                "content"] = generated_token
-                        else:
-                            choices[0]["text"] = generated_token
-                        chunk = json.dumps(chunk_json).encode("utf-8")
-                    yield chunk
+            async for chunk in stream_service_response_with_retry(
+                    decoder.client,
+                    api,
+                    req_data,
+                    request_id=request_id,
+                    max_retries=global_args.max_retries,
+                    base_delay=global_args.retry_delay):
+                yield chunk
         except Exception as e:
             logger.error(
-                f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
+                f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
             )
-            proxy_state.abort_prefiller_request(
-                instance_info.prefiller_idx, instance_info.request_id)
-            proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
-                                             instance_info.prefiller_score)
 
         # After streaming done, release tokens
-        proxy_state.release_decoder(instance_info.decoder_idx,
-                                    instance_info.decoder_score)
+        proxy_state.release_decoder(decoder_idx, decoder_score)
 
     return StreamingResponse(generate_stream(),
                              media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
 @app.post("/v1/metaserver")
 async def metaserver(request: Request):
     try:
-        req_data = await request.json()
-        request_id = req_data.pop("request_id", None)
-        if request_id in proxy_state.req_id_future:
-            result_future = proxy_state.req_id_future[request_id]
-            result_future.set_result(req_data)
+        kv_transfer_params = await request.json()
+        request_id = kv_transfer_params["request_id"]
+        assert request_id in proxy_state.req_data_dict
+        req_data, request_length, api = proxy_state.req_data_dict[request_id]
+        request_id = get_origin_request_id(api, request_id)
+        req_data["kv_transfer_params"] = kv_transfer_params
+        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
+        logger.debug(
+            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
+        )
+
+        # Select prefiller
+        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
+        prefiller = proxy_state.prefillers[prefiller_idx]
+        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
+        # Send request to prefiller
+        response = await send_request_to_service(
+            prefiller.client,
+            prefiller_idx,
+            api,
+            req_data,
+            request_id,
+            max_retries=global_args.max_retries,
+            base_delay=global_args.retry_delay)
+        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
+
     except Exception as e:
         logger.error(f"Post metaserver failed with: {str(e)}")
@@ -682,5 +577,4 @@ if __name__ == '__main__':
     global global_args
     global_args = parse_args()
     import uvicorn
-
     uvicorn.run(app, host=global_args.host, port=global_args.port)
File diff suppressed because it is too large.
@@ -96,7 +96,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
         parallel_config.data_parallel_size, num_head_replica, -1,
         alltoall_group_size
     )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
-    group_ranks = group_ranks.view(-1, alltoall_group_size).unbind(0)
+    group_ranks = group_ranks.reshape(-1,
+                                      alltoall_group_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    local_rank = get_world_group().local_rank
    num = next(
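The switch from `view` to `reshape` matters when the grouped rank tensor's memory layout is not contiguous (presumably the case here after the DP/head-replica regrouping): `view()` refuses stride-incompatible layouts, while `reshape()` falls back to a copy. A small, hedged illustration with made-up shapes, not this module's actual tensors:

```python
import torch

# A permuted tensor is no longer contiguous in memory.
x = torch.arange(16).reshape(2, 2, 2, 2).permute(0, 2, 1, 3)
try:
    x.view(-1, 2)               # view() raises on this layout
except RuntimeError as err:
    print("view() rejected:", err)
print(x.reshape(-1, 2).unbind(0))  # reshape() copies when needed, then unbind works
```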