feat: add direct routing strategy to DP worker (#6884)

This commit is contained in:
ishandhanani
2025-06-09 11:44:05 -07:00
committed by GitHub
parent 3465d7ae78
commit f1569876d5
12 changed files with 78 additions and 8 deletions

View File

@@ -167,11 +167,22 @@ class Engine(EngineBase):
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
) -> Union[Dict, Iterator[Dict]]:
"""
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.info("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
)
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
@@ -188,6 +199,7 @@ class Engine(EngineBase):
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
loop = asyncio.get_event_loop()
generator = self.tokenizer_manager.generate_request(obj, None)
@@ -237,11 +249,24 @@ class Engine(EngineBase):
bootstrap_host: Optional[Union[List[str], str]] = None,
bootstrap_port: Optional[Union[List[int], int]] = None,
bootstrap_room: Optional[Union[List[int], int]] = None,
data_parallel_rank: Optional[int] = None,
) -> Union[Dict, AsyncIterator[Dict]]:
"""
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
Please refer to `GenerateReqInput` for the documentation.
"""
if self.server_args.enable_dp_attention:
if data_parallel_rank is None:
logger.info("data_parallel_rank not provided, using default dispatch")
elif data_parallel_rank < 0:
raise ValueError("data_parallel_rank must be non-negative")
elif data_parallel_rank >= self.server_args.dp_size:
raise ValueError(
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
)
logger.info(f"data_parallel_rank: {data_parallel_rank}")
obj = GenerateReqInput(
text=prompt,
input_ids=input_ids,
@@ -257,6 +282,7 @@ class Engine(EngineBase):
bootstrap_host=bootstrap_host,
bootstrap_port=bootstrap_port,
bootstrap_room=bootstrap_room,
data_parallel_rank=data_parallel_rank,
)
generator = self.tokenizer_manager.generate_request(obj, None)