[PD] Fix server crash when using batch requests (#5531)
This commit is contained in:
@@ -161,12 +161,24 @@ async def handle_generate_request(request_data: dict):
|
||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||
hostname = parsed_url.hostname
|
||||
modified_request = request_data.copy()
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
||||
}
|
||||
)
|
||||
|
||||
batch_size = _get_request_batch_size(modified_request)
|
||||
if batch_size is not None:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": [hostname] * batch_size,
|
||||
"bootstrap_room": [
|
||||
_generate_bootstrap_room() for _ in range(batch_size)
|
||||
],
|
||||
}
|
||||
)
|
||||
else:
|
||||
modified_request.update(
|
||||
{
|
||||
"bootstrap_host": hostname,
|
||||
"bootstrap_room": _generate_bootstrap_room(),
|
||||
}
|
||||
)
|
||||
|
||||
if request_data.get("stream", False):
|
||||
return await load_balancer.generate_stream(
|
||||
@@ -178,6 +190,19 @@ async def handle_generate_request(request_data: dict):
|
||||
)
|
||||
|
||||
|
||||
def _generate_bootstrap_room():
|
||||
return random.randint(0, 2**63 - 1)
|
||||
|
||||
|
||||
# We may utilize `GenerateReqInput`'s logic later
|
||||
def _get_request_batch_size(request):
|
||||
if (text := request.get("text")) is not None:
|
||||
return None if isinstance(text, str) else len(text)
|
||||
if (input_ids := request.get("input_ids")) is not None:
|
||||
return None if isinstance(input_ids[0], int) else len(input_ids)
|
||||
return None
|
||||
|
||||
|
||||
@app.get("/v1/models")
|
||||
async def get_models():
|
||||
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
||||
|
||||
@@ -96,8 +96,8 @@ class GenerateReqInput:
|
||||
return_hidden_states: bool = False
|
||||
|
||||
# For disaggregated inference
|
||||
bootstrap_host: Optional[str] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
def normalize_batch_and_arguments(self):
|
||||
"""
|
||||
@@ -397,6 +397,12 @@ class GenerateReqInput:
|
||||
else None
|
||||
),
|
||||
return_hidden_states=self.return_hidden_states,
|
||||
bootstrap_host=(
|
||||
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
||||
),
|
||||
bootstrap_room=(
|
||||
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user