feat: add direct routing strategy to DP worker (#6884)
This commit is contained in:
@@ -167,11 +167,22 @@ class Engine(EngineBase):
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Union[Dict, Iterator[Dict]]:
|
||||
"""
|
||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||
Please refer to `GenerateReqInput` for the documentation.
|
||||
"""
|
||||
if self.server_args.enable_dp_attention:
|
||||
if data_parallel_rank is None:
|
||||
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||
elif data_parallel_rank < 0:
|
||||
raise ValueError("data_parallel_rank must be non-negative")
|
||||
elif data_parallel_rank >= self.server_args.dp_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
|
||||
)
|
||||
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
@@ -188,6 +199,7 @@ class Engine(EngineBase):
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||
@@ -237,11 +249,24 @@ class Engine(EngineBase):
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Union[Dict, AsyncIterator[Dict]]:
|
||||
"""
|
||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||
Please refer to `GenerateReqInput` for the documentation.
|
||||
"""
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
if data_parallel_rank is None:
|
||||
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||
elif data_parallel_rank < 0:
|
||||
raise ValueError("data_parallel_rank must be non-negative")
|
||||
elif data_parallel_rank >= self.server_args.dp_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
|
||||
)
|
||||
|
||||
logger.info(f"data_parallel_rank: {data_parallel_rank}")
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
@@ -257,6 +282,7 @@ class Engine(EngineBase):
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user