[PD]: Support Multi Prefill in one node (#5704)
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -137,7 +137,7 @@ class DecodePreallocQueue:
|
||||
kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
|
||||
kv_receiver = kv_receiver_class(
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
)
|
||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||
|
||||
@@ -6,6 +6,7 @@ import asyncio
|
||||
import random
|
||||
import urllib
|
||||
from itertools import chain
|
||||
from typing import List
|
||||
|
||||
import aiohttp
|
||||
import orjson
|
||||
@@ -14,13 +15,22 @@ from fastapi import FastAPI, HTTPException
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
|
||||
class PrefillConfig:
    """Connection details for one prefill server.

    Holds the two constructor arguments unchanged as the attributes
    ``url`` and ``bootstrap_port``.
    """

    def __init__(self, url: str, bootstrap_port: int):
        # No validation or normalization is applied to either value.
        self.url, self.bootstrap_port = url, bootstrap_port
|
||||
|
||||
|
||||
class MiniLoadBalancer:
|
||||
def __init__(self, prefill_servers, decode_servers):
|
||||
self.prefill_servers = prefill_servers
|
||||
def __init__(self, prefill_configs: List[PrefillConfig], decode_servers: List[str]):
|
||||
self.prefill_configs = prefill_configs
|
||||
self.prefill_servers = [p.url for p in prefill_configs]
|
||||
self.decode_servers = decode_servers
|
||||
|
||||
def select_pair(self):
|
||||
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
||||
prefill_config = random.choice(self.prefill_configs)
|
||||
decode_server = random.choice(self.decode_servers)
|
||||
return prefill_config.url, prefill_config.bootstrap_port, decode_server
|
||||
|
||||
async def generate(
|
||||
self, modified_request, prefill_server, decode_server, endpoint
|
||||
@@ -160,7 +170,7 @@ async def get_model_info():
|
||||
|
||||
@app.post("/generate")
|
||||
async def handle_generate_request(request_data: dict):
|
||||
prefill_server, decode_server = load_balancer.select_pair()
|
||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
@@ -172,6 +182,7 @@ async def handle_generate_request(request_data: dict):
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_port": [bootstrap_port] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
@@ -181,6 +192,7 @@ async def handle_generate_request(request_data: dict):
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
@@ -197,7 +209,7 @@ async def handle_generate_request(request_data: dict):
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def handle_completion_request(request_data: dict):
|
||||
prefill_server, decode_server = load_balancer.select_pair()
|
||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||
|
||||
# Parse and transform prefill_server for bootstrap data
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
@@ -206,6 +218,7 @@ async def handle_completion_request(request_data: dict):
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_port": bootstrap_port,
|
||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
||||
}
|
||||
)
|
||||
@@ -255,9 +268,9 @@ async def get_models():
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def run(prefill_addrs, decode_addrs, host, port):
|
||||
def run(prefill_configs, decode_addrs, host, port):
|
||||
global load_balancer
|
||||
load_balancer = MiniLoadBalancer(prefill_addrs, decode_addrs)
|
||||
load_balancer = MiniLoadBalancer(prefill_configs, decode_addrs)
|
||||
uvicorn.run(app, host=host, port=port)
|
||||
|
||||
|
||||
@@ -268,6 +281,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--prefill", required=True, help="Comma-separated URLs for prefill servers"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefill-bootstrap-ports",
|
||||
help="Comma-separated bootstrap ports for prefill servers",
|
||||
default="8998",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decode", required=True, help="Comma-separated URLs for decode servers"
|
||||
)
|
||||
@@ -278,4 +296,23 @@ if __name__ == "__main__":
|
||||
"--port", type=int, default=8000, help="Port to bind the server (default: 8000)"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
run(args.prefill.split(","), args.decode.split(","), args.host, args.port)
|
||||
|
||||
# Pair every prefill URL with the bootstrap port it should advertise.
prefill_urls = args.prefill.split(",")
bootstrap_ports = [int(p) for p in args.prefill_bootstrap_ports.split(",")]

if len(bootstrap_ports) == 1:
    # A single port applies to every prefill server.
    bootstrap_ports = bootstrap_ports * len(prefill_urls)
elif len(bootstrap_ports) != len(prefill_urls):
    # Raising aborts startup on its own; the exit(1) that used to follow
    # this raise could never execute and has been removed.
    raise ValueError(
        "Number of prefill URLs must match number of bootstrap ports"
    )

# The lengths are equal at this point, so zip() drops nothing.
prefill_configs = [
    PrefillConfig(url, port) for url, port in zip(prefill_urls, bootstrap_ports)
]

decode_addrs = args.decode.split(",")

run(prefill_configs, decode_addrs, args.host, args.port)
|
||||
|
||||
@@ -97,6 +97,7 @@ class GenerateReqInput:
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
@@ -400,6 +401,9 @@ class GenerateReqInput:
|
||||
bootstrap_host=(
|
||||
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
||||
),
|
||||
bootstrap_port=(
|
||||
self.bootstrap_port[i] if self.bootstrap_port is not None else None
|
||||
),
|
||||
bootstrap_room=(
|
||||
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
||||
),
|
||||
@@ -447,6 +451,7 @@ class TokenizedGenerateReqInput:
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_port: Optional[int] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
|
||||
|
||||
@@ -391,6 +391,7 @@ class Req:
|
||||
return_hidden_states: bool = False,
|
||||
eos_token_ids: Optional[Set[int]] = None,
|
||||
bootstrap_host: Optional[str] = None,
|
||||
bootstrap_port: Optional[int] = None,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
):
|
||||
# Input and output info
|
||||
@@ -526,6 +527,7 @@ class Req:
|
||||
|
||||
# For disaggregation
|
||||
self.bootstrap_host: str = bootstrap_host
|
||||
self.bootstrap_port: Optional[int] = bootstrap_port
|
||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||
|
||||
|
||||
@@ -791,6 +791,7 @@ class Scheduler(
|
||||
return_hidden_states=recv_req.return_hidden_states,
|
||||
eos_token_ids=self.model_config.hf_eos_token_id,
|
||||
bootstrap_host=recv_req.bootstrap_host,
|
||||
bootstrap_port=recv_req.bootstrap_port,
|
||||
bootstrap_room=recv_req.bootstrap_room,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
@@ -498,6 +498,7 @@ class TokenizerManager:
|
||||
token_ids_logprob,
|
||||
obj.stream,
|
||||
bootstrap_host=obj.bootstrap_host,
|
||||
bootstrap_port=obj.bootstrap_port,
|
||||
bootstrap_room=obj.bootstrap_room,
|
||||
lora_path=obj.lora_path,
|
||||
input_embeds=input_embeds,
|
||||
|
||||
Reference in New Issue
Block a user