[bugfix_v0.11.0-dev] layerwise D first plan (#3907)
### What this PR does / why we need it?

Refactors the layerwise code so that requests are dispatched to the decode (D) node first, preventing prefill (P) node hangs caused by communication timeouts when DP > 1.

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
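In rough terms, the change inverts the dispatch order in the proxy: the decoder is contacted first, and a prefiller is only engaged once the decoder reports back through the proxy's `/v1/metaserver` endpoint. The sketch below illustrates that ordering only; the function names, URL, and payload fields are simplified stand-ins for illustration, not the proxy's actual API (see the diff below for the real code).

```python
import asyncio

# Proxy-side bookkeeping: original request payloads keyed by request id,
# so the metaserver callback can recover them later.
pending: dict = {}


async def send_to_decoder(request_id: str, req_data: dict) -> None:
    # Stand-in for streaming from the decode (D) node. In the real proxy the
    # decoder calls back to /v1/metaserver once it can receive KV blocks.
    print(f"[D] decoder serving {request_id}, waiting for KV from a prefiller")
    await metaserver({"request_id": request_id, "remote_host": "decoder-0"})


async def metaserver(kv_transfer_params: dict) -> None:
    # Only now is a prefill (P) node chosen and the request forwarded to it,
    # so the prefiller never blocks on a decoder that has not been set up.
    req_data = pending[kv_transfer_params["request_id"]]
    req_data["kv_transfer_params"] = kv_transfer_params
    print(f"[P] prefiller handling {kv_transfer_params['request_id']}, "
          f"pushing KV to {kv_transfer_params['remote_host']}")


async def handle_completion(request_id: str, req_data: dict) -> None:
    pending[request_id] = req_data
    # Mark the request as needing a remote prefill and tell the decoder where
    # the proxy's metaserver callback lives (hypothetical URL).
    req_data["kv_transfer_params"] = {
        "do_remote_prefill": True,
        "metaserver": "http://proxy:8000/v1/metaserver",
    }
    await send_to_decoder(request_id, req_data)  # D node first


asyncio.run(handle_completion("cmpl-0", {"prompt": "hi", "max_tokens": 16}))
```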
@@ -88,18 +88,17 @@ import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import threading
import uuid
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, List
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer
from vllm.logger import init_logger

logger = init_logger(__name__)
@@ -107,7 +106,6 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass
@@ -154,6 +152,9 @@ class ProxyState:
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)
        self.req_id_future = {}
        self.req_data_dict = {}
        self.tokenizer = AutoTokenizer.from_pretrained(
            global_args.tokenizer_dir)

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
@@ -280,6 +281,10 @@ def parse_args():
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--tokenizer-dir",
                        type=str,
                        default="/mnt/weight/Qwen3-235B-A22B-W8A8",
                        help="Path to the tokenizer directory")
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
@@ -356,17 +361,6 @@ async def send_request_to_service(client: httpx.AsyncClient,
    aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
        "metaserver":
        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
@@ -458,59 +452,11 @@ def get_api_request_id(api, req_id):
    return "chatcmpl-" + req_id


async def _handle_select_instance(api: str, req_data: Any,
                                  request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(
        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
    )
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    result_future = asyncio.Future()  # type: ignore
    request_id_api = get_api_request_id(api, request_id)
    proxy_state.req_id_future[request_id_api] = result_future
    # Send request to prefiller
    asyncio.get_running_loop().create_task(
        send_request_to_service(prefiller.client,
                                prefiller_idx,
                                api,
                                req_data,
                                request_id,
                                max_retries=global_args.max_retries,
                                base_delay=global_args.retry_delay))
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)

    response = await result_future
    del proxy_state.req_id_future[request_id_api]
    req_data["kv_transfer_params"] = response

    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(request_id=request_id,
                        prefiller_idx=prefiller_idx,
                        prefiller_score=prefiller_score,
                        prefiller=prefiller,
                        decoder=decoder,
                        decoder_idx=decoder_idx,
                        decoder_score=decoder_score)


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState
def get_origin_request_id(api, req_id):
    if api == "/completions":
        return req_id.replace("cmpl-", "").replace("-0", "")
    elif api == "/chat/completions":
        return req_id.replace("chatcmpl-", "")


async def _handle_completions(api: str, request: Request):
@@ -518,120 +464,47 @@ async def _handle_completions(api: str, request: Request):
    req_data = await request.json()
    req_body = await request.body()
    request_length = len(req_body)
    instance_info = await _handle_select_instance(api, req_data,
                                                  request_length)
    stream_flag = bool(req_data.get("stream", False))
    chat_flag = "messages" in req_data

    if "prompt" in req_data:
        origin_prompt = req_data["prompt"]
    elif chat_flag:
        messages = req_data["messages"]
        origin_prompt = messages[0].get("content", "")
    else:
        origin_prompt = ""
    # refer to vLLM sampling_params: max_tokens default value
    origin_max_tokens = req_data.get("max_tokens", 16)
    request_id = await proxy_state.next_req_id()
    request_id_api = get_api_request_id(api, request_id)
    proxy_state.req_data_dict[request_id_api] = (req_data, request_length,
                                                 api)
    req_data['kv_transfer_params'] = {
        "do_remote_decode":
        False,
        "do_remote_prefill":
        True,
        "metaserver":
        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
    }
    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    # logger.debug("Using %s %s", prefiller.url, decoder.url)
    # Stream response from decoder
    released_kv = False

    async def generate_stream():
        nonlocal instance_info
        generated_token = ""
        released_kv = False
        retry_count = 0
        retry = True
        completion_tokens = 0
        nonlocal released_kv
        # Only one await per chunk, minimal logic in loop
        try:
            while retry:
                retry = False
                async for chunk in stream_service_response_with_retry(
                        instance_info.decoder.client,
                        api,
                        req_data,
                        request_id=instance_info.request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            instance_info.prefiller_idx,
                            instance_info.prefiller_score)
                        released_kv = True
                    try:
                        chunk_str = chunk.decode("utf-8").strip()
                    except UnicodeDecodeError:
                        logger.debug(
                            f"Skipping chunk: {chunk}")
                        yield chunk
                        continue
                    if not chunk_str:
                        continue
                    if chunk_str.startswith("data: "):
                        chunk_str = chunk_str[len("data: "):]
                    try:
                        chunk_json = json.loads(chunk_str)
                    except json.JSONDecodeError:
                        # if chunk is [done], skip it.
                        logger.debug(
                            f"Skipping chunk: {chunk_str}")
                        yield chunk
                        continue
                    choices = chunk_json.get("choices", [])
                    if not choices:
                        yield chunk
                        continue

                    choice = choices[0]
                    delta = choice.get("delta") or {}
                    message = choice.get("message") or {}
                    content = (
                        delta.get("content")
                        or message.get("content")
                        or choice.get("text")
                        or ""
                    )
                    generated_token += content

                    stop_reason = choice.get(
                        "stop_reason")
                    usage = chunk_json.get("usage", {})
                    completion_tokens = (completion_tokens + 1) if stream_flag else \
                        (completion_tokens + usage.get("completion_tokens"))
                    if stop_reason == "recomputed":
                        retry = True
                        retry_count += 1
                        if chat_flag:
                            messages[0][
                                "content"] = origin_prompt + generated_token
                        else:
                            req_data[
                                "prompt"] = origin_prompt + generated_token
                        req_data[
                            "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                        tmp_request_length = len(
                            json.dumps(req_data).encode("utf-8"))
                        instance_info = await _handle_select_instance(
                            api, req_data, tmp_request_length)
                        break
                    if retry_count > 0 and not stream_flag:
                        if chat_flag:
                            choices[0]["message"][
                                "content"] = generated_token
                        else:
                            choices[0]["text"] = generated_token
                        chunk = json.dumps(chunk_json).encode("utf-8")
                    yield chunk
            async for chunk in stream_service_response_with_retry(
                    decoder.client,
                    api,
                    req_data,
                    request_id=request_id,
                    max_retries=global_args.max_retries,
                    base_delay=global_args.retry_delay):
                yield chunk
        except Exception as e:
            logger.error(
                f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
            )
            proxy_state.abort_prefiller_request(
                instance_info.prefiller_idx, instance_info.request_id)
            proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
                                             instance_info.prefiller_score)

        # After streaming done, release tokens
        proxy_state.release_decoder(instance_info.decoder_idx,
                                    instance_info.decoder_score)
        proxy_state.release_decoder(decoder_idx, decoder_score)

    return StreamingResponse(generate_stream(),
                             media_type="application/json")
@@ -669,11 +542,33 @@ async def healthcheck():
@app.post("/v1/metaserver")
async def metaserver(request: Request):
    try:
        req_data = await request.json()
        request_id = req_data.pop("request_id", None)
        if request_id in proxy_state.req_id_future:
            result_future = proxy_state.req_id_future[request_id]
            result_future.set_result(req_data)
        kv_transfer_params = await request.json()

        request_id = kv_transfer_params["request_id"]
        assert request_id in proxy_state.req_data_dict
        req_data, request_length, api = proxy_state.req_data_dict[request_id]
        request_id = get_origin_request_id(api, request_id)
        req_data["kv_transfer_params"] = kv_transfer_params
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )

        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        logger.debug(f"Using prefill {prefiller.url=} {req_data=}")
        # Send request to prefiller
        response = await send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay)
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)

    except Exception as e:
        logger.error(f"Post metaserver failed with: {str(e)}")
@@ -682,5 +577,4 @@ if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -1136,4 +1136,4 @@ class TestMooncakeConnectorWorker(unittest.TestCase):


if __name__ == '__main__':
    unittest.main()
    unittest.main()
File diff suppressed because it is too large
File diff suppressed because it is too large
@@ -89,7 +89,8 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
        parallel_config.data_parallel_size, num_head_replica, -1,
        alltoall_group_size
    )  # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
    group_ranks = group_ranks.view(-1, alltoall_group_size).unbind(0)
    group_ranks = group_ranks.reshape(-1,
                                      alltoall_group_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    local_rank = get_world_group().local_rank
    num = next(