[PD] Fix server crash when using batch requests (#5531)
This commit is contained in:
@@ -161,12 +161,24 @@ async def handle_generate_request(request_data: dict):
|
|||||||
parsed_url = urllib.parse.urlparse(prefill_server)
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
||||||
hostname = parsed_url.hostname
|
hostname = parsed_url.hostname
|
||||||
modified_request = request_data.copy()
|
modified_request = request_data.copy()
|
||||||
modified_request.update(
|
|
||||||
{
|
batch_size = _get_request_batch_size(modified_request)
|
||||||
"bootstrap_host": hostname,
|
if batch_size is not None:
|
||||||
"bootstrap_room": random.randint(0, 2**63 - 1),
|
modified_request.update(
|
||||||
}
|
{
|
||||||
)
|
"bootstrap_host": [hostname] * batch_size,
|
||||||
|
"bootstrap_room": [
|
||||||
|
_generate_bootstrap_room() for _ in range(batch_size)
|
||||||
|
],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
modified_request.update(
|
||||||
|
{
|
||||||
|
"bootstrap_host": hostname,
|
||||||
|
"bootstrap_room": _generate_bootstrap_room(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
if request_data.get("stream", False):
|
if request_data.get("stream", False):
|
||||||
return await load_balancer.generate_stream(
|
return await load_balancer.generate_stream(
|
||||||
@@ -178,6 +190,19 @@ async def handle_generate_request(request_data: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_bootstrap_room():
    """Draw a random bootstrap room id.

    Returns:
        A uniformly drawn integer in the inclusive range [0, 2**63 - 1],
        taken from the module-level ``random`` generator.
    """
    # Largest value that still fits in a signed 64-bit integer.
    max_room_id = 2**63 - 1
    return random.randint(0, max_room_id)
|
||||||
|
|
||||||
|
|
||||||
|
# We may utilize `GenerateReqInput`'s logic later
|
||||||
|
def _get_request_batch_size(request):
    """Return the batch size of a generate request, or None if it is not batched.

    The request is inspected in order: ``"text"`` first, then ``"input_ids"``.

    Args:
        request: Mapping holding the request payload (only ``.get`` is used).

    Returns:
        * ``None`` when ``text`` is a single string, when ``input_ids`` is a
          flat list of ints, when ``input_ids`` is empty, or when neither key
          carries a value.
        * ``len(text)`` / ``len(input_ids)`` otherwise, i.e. the number of
          entries in the batch.
    """
    if (text := request.get("text")) is not None:
        # A bare string is one prompt; anything else is treated as a sequence of prompts.
        return None if isinstance(text, str) else len(text)
    if (input_ids := request.get("input_ids")) is not None:
        # An empty list has no first element to inspect, so indexing it would
        # raise IndexError. Nothing marks it as a batch, so report "not batched"
        # and let the caller fall back to scalar bootstrap fields.
        if not input_ids or isinstance(input_ids[0], int):
            return None
        return len(input_ids)
    return None
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/models")
|
@app.get("/v1/models")
|
||||||
async def get_models():
|
async def get_models():
|
||||||
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
||||||
|
|||||||
@@ -96,8 +96,8 @@ class GenerateReqInput:
|
|||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
|
|
||||||
# For disaggregated inference
|
# For disaggregated inference
|
||||||
bootstrap_host: Optional[str] = None
|
bootstrap_host: Optional[Union[List[str], str]] = None
|
||||||
bootstrap_room: Optional[int] = None
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||||
|
|
||||||
def normalize_batch_and_arguments(self):
|
def normalize_batch_and_arguments(self):
|
||||||
"""
|
"""
|
||||||
@@ -397,6 +397,12 @@ class GenerateReqInput:
|
|||||||
else None
|
else None
|
||||||
),
|
),
|
||||||
return_hidden_states=self.return_hidden_states,
|
return_hidden_states=self.return_hidden_states,
|
||||||
|
bootstrap_host=(
|
||||||
|
self.bootstrap_host[i] if self.bootstrap_host is not None else None
|
||||||
|
),
|
||||||
|
bootstrap_room=(
|
||||||
|
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user