diff --git a/python/sglang/srt/disaggregation/mini_lb.py b/python/sglang/srt/disaggregation/mini_lb.py
index 1dab5c9b3..3b8422421 100644
--- a/python/sglang/srt/disaggregation/mini_lb.py
+++ b/python/sglang/srt/disaggregation/mini_lb.py
@@ -161,12 +161,24 @@ async def handle_generate_request(request_data: dict):
     parsed_url = urllib.parse.urlparse(prefill_server)
     hostname = parsed_url.hostname
     modified_request = request_data.copy()
-    modified_request.update(
-        {
-            "bootstrap_host": hostname,
-            "bootstrap_room": random.randint(0, 2**63 - 1),
-        }
-    )
+
+    batch_size = _get_request_batch_size(modified_request)
+    if batch_size is not None:
+        modified_request.update(
+            {
+                "bootstrap_host": [hostname] * batch_size,
+                "bootstrap_room": [
+                    _generate_bootstrap_room() for _ in range(batch_size)
+                ],
+            }
+        )
+    else:
+        modified_request.update(
+            {
+                "bootstrap_host": hostname,
+                "bootstrap_room": _generate_bootstrap_room(),
+            }
+        )
 
     if request_data.get("stream", False):
         return await load_balancer.generate_stream(
@@ -178,6 +190,30 @@ async def handle_generate_request(request_data: dict):
         )
 
 
+def _generate_bootstrap_room():
+    """Return a random integer in [0, 2**63 - 1] to use as a bootstrap room id."""
+    return random.randint(0, 2**63 - 1)
+
+
+# We may utilize `GenerateReqInput`'s logic later
+def _get_request_batch_size(request):
+    """Return the batch size of a generate request, or None if it is not batched.
+
+    A request counts as batched when `text` is not a single string or when
+    `input_ids` is a list whose first element is not an int. `text` takes
+    precedence over `input_ids` when both are present.
+    """
+    if (text := request.get("text")) is not None:
+        return None if isinstance(text, str) else len(text)
+    if (input_ids := request.get("input_ids")) is not None:
+        # An empty `input_ids` has no first element to inspect; treat it as a
+        # non-batched request instead of raising IndexError.
+        if not input_ids or isinstance(input_ids[0], int):
+            return None
+        return len(input_ids)
+    return None
+
+
 @app.get("/v1/models")
 async def get_models():
     prefill_server = load_balancer.prefill_servers[0]  # Get the first prefill server
diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py
index dbb632bd2..e8590c950 100644
--- a/python/sglang/srt/managers/io_struct.py
+++ b/python/sglang/srt/managers/io_struct.py
@@ -96,8 +96,8 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +397,12 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
 