[PD] Support completion endpoint (#6729)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -274,8 +274,7 @@ async def handle_generate_request(request_data: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/chat/completions")
|
async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
||||||
async def handle_completion_request(request_data: dict):
|
|
||||||
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
prefill_server, bootstrap_port, decode_server = load_balancer.select_pair()
|
||||||
|
|
||||||
# Parse and transform prefill_server for bootstrap data
|
# Parse and transform prefill_server for bootstrap data
|
||||||
@@ -286,7 +285,7 @@ async def handle_completion_request(request_data: dict):
|
|||||||
{
|
{
|
||||||
"bootstrap_host": hostname,
|
"bootstrap_host": hostname,
|
||||||
"bootstrap_port": bootstrap_port,
|
"bootstrap_port": bootstrap_port,
|
||||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
"bootstrap_room": _generate_bootstrap_room(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -295,17 +294,27 @@ async def handle_completion_request(request_data: dict):
|
|||||||
modified_request,
|
modified_request,
|
||||||
prefill_server,
|
prefill_server,
|
||||||
decode_server,
|
decode_server,
|
||||||
endpoint="v1/chat/completions",
|
endpoint=endpoint_name,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return await load_balancer.generate(
|
return await load_balancer.generate(
|
||||||
modified_request,
|
modified_request,
|
||||||
prefill_server,
|
prefill_server,
|
||||||
decode_server,
|
decode_server,
|
||||||
endpoint="v1/chat/completions",
|
endpoint=endpoint_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/chat/completions")
|
||||||
|
async def handle_chat_completion_request(request_data: dict):
|
||||||
|
return await _forward_to_backend(request_data, "v1/chat/completions")
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/v1/completions")
|
||||||
|
async def handle_completion_request(request_data: dict):
|
||||||
|
return await _forward_to_backend(request_data, "v1/completions")
|
||||||
|
|
||||||
|
|
||||||
def _generate_bootstrap_room():
|
def _generate_bootstrap_room():
|
||||||
return random.randint(0, 2**63 - 1)
|
return random.randint(0, 2**63 - 1)
|
||||||
|
|
||||||
|
|||||||
@@ -604,6 +604,9 @@ def v1_generate_request(
|
|||||||
stream=all_requests[0].stream,
|
stream=all_requests[0].stream,
|
||||||
rid=request_ids,
|
rid=request_ids,
|
||||||
lora_path=lora_paths,
|
lora_path=lora_paths,
|
||||||
|
bootstrap_host=all_requests[0].bootstrap_host,
|
||||||
|
bootstrap_port=all_requests[0].bootstrap_port,
|
||||||
|
bootstrap_room=all_requests[0].bootstrap_room,
|
||||||
)
|
)
|
||||||
|
|
||||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||||
|
|||||||
@@ -183,12 +183,17 @@ class CompletionRequest(BaseModel):
|
|||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||||
session_params: Optional[Dict] = None
|
session_params: Optional[Dict] = None
|
||||||
|
|
||||||
|
# For PD disaggregation
|
||||||
|
bootstrap_host: Optional[str] = None
|
||||||
|
bootstrap_port: Optional[int] = None
|
||||||
|
bootstrap_room: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
class CompletionResponseChoice(BaseModel):
|
class CompletionResponseChoice(BaseModel):
|
||||||
index: int
|
index: int
|
||||||
text: str
|
text: str
|
||||||
logprobs: Optional[LogProbs] = None
|
logprobs: Optional[LogProbs] = None
|
||||||
finish_reason: Literal["stop", "length", "content_filter"]
|
finish_reason: Literal["stop", "length", "content_filter", "abort"]
|
||||||
matched_stop: Union[None, int, str] = None
|
matched_stop: Union[None, int, str] = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user