feat: add direct routing strategy to DP worker (#6884)
This commit is contained in:
@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
self.kv_mgr = mgr
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_tp_size, self.prefill_dp_size = (
|
||||
@@ -180,7 +182,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
|
||||
@@ -156,6 +156,7 @@ class DecodePreallocQueue:
|
||||
mgr=self.kv_manager,
|
||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||
bootstrap_room=req.bootstrap_room,
|
||||
data_parallel_rank=req.data_parallel_rank,
|
||||
)
|
||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||
|
||||
|
||||
@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
|
||||
mgr: BaseKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.has_init = False
|
||||
|
||||
|
||||
@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
mgr: MooncakeKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.bootstrap_room = bootstrap_room
|
||||
self.bootstrap_addr = bootstrap_addr
|
||||
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.session_id = self.kv_mgr.get_session_id()
|
||||
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||
self.conclude_state = None
|
||||
self.data_parallel_rank = data_parallel_rank
|
||||
|
||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_tp_size, self.prefill_dp_size = (
|
||||
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
|
||||
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
|
||||
if self.data_parallel_rank is not None:
|
||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||
self.target_dp_group = self.data_parallel_rank
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
|
||||
@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
mgr: NixlKVManager,
|
||||
bootstrap_addr: str,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
self.started_transfer = False
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room)
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
|
||||
@@ -23,6 +23,12 @@ class EngineBase(ABC):
|
||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
|
||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||
return_hidden_states: Optional[bool] = None,
|
||||
stream: Optional[bool] = None,
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Union[Dict, Iterator[Dict]]:
|
||||
"""Generate outputs based on given inputs."""
|
||||
pass
|
||||
|
||||
@@ -167,11 +167,22 @@ class Engine(EngineBase):
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Union[Dict, Iterator[Dict]]:
|
||||
"""
|
||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||
Please refer to `GenerateReqInput` for the documentation.
|
||||
"""
|
||||
if self.server_args.enable_dp_attention:
|
||||
if data_parallel_rank is None:
|
||||
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||
elif data_parallel_rank < 0:
|
||||
raise ValueError("data_parallel_rank must be non-negative")
|
||||
elif data_parallel_rank >= self.server_args.dp_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
|
||||
)
|
||||
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
@@ -188,6 +199,7 @@ class Engine(EngineBase):
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||
@@ -237,11 +249,24 @@ class Engine(EngineBase):
|
||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
) -> Union[Dict, AsyncIterator[Dict]]:
|
||||
"""
|
||||
The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||
Please refer to `GenerateReqInput` for the documentation.
|
||||
"""
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
if data_parallel_rank is None:
|
||||
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||
elif data_parallel_rank < 0:
|
||||
raise ValueError("data_parallel_rank must be non-negative")
|
||||
elif data_parallel_rank >= self.server_args.dp_size:
|
||||
raise ValueError(
|
||||
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
|
||||
)
|
||||
|
||||
logger.info(f"data_parallel_rank: {data_parallel_rank}")
|
||||
obj = GenerateReqInput(
|
||||
text=prompt,
|
||||
input_ids=input_ids,
|
||||
@@ -257,6 +282,7 @@ class Engine(EngineBase):
|
||||
bootstrap_host=bootstrap_host,
|
||||
bootstrap_port=bootstrap_port,
|
||||
bootstrap_room=bootstrap_room,
|
||||
data_parallel_rank=data_parallel_rank,
|
||||
)
|
||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||
|
||||
|
||||
@@ -248,12 +248,20 @@ class DataParallelController:
|
||||
|
||||
def round_robin_scheduler(self, req: Req):
|
||||
if self.server_args.disaggregation_mode == "null":
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||
self.workers
|
||||
)
|
||||
if req.data_parallel_rank is not None:
|
||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||
else:
|
||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||
self.workers
|
||||
)
|
||||
else:
|
||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||
if req.data_parallel_rank is not None:
|
||||
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||
else:
|
||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||
|
||||
def shortest_queue_scheduler(self, input_requests):
|
||||
raise NotImplementedError()
|
||||
|
||||
@@ -106,6 +106,9 @@ class GenerateReqInput:
|
||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||
|
||||
# For data parallel rank routing
|
||||
data_parallel_rank: Optional[int] = None
|
||||
|
||||
def contains_mm_input(self) -> bool:
|
||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||
|
||||
@@ -417,6 +420,9 @@ class GenerateReqInput:
|
||||
bootstrap_room=(
|
||||
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
||||
),
|
||||
data_parallel_rank=(
|
||||
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
|
||||
bootstrap_port: Optional[int] = None
|
||||
bootstrap_room: Optional[int] = None
|
||||
|
||||
# For data parallel rank routing
|
||||
data_parallel_rank: Optional[int] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmbeddingReqInput:
|
||||
|
||||
@@ -451,6 +451,7 @@ class Req:
|
||||
bootstrap_host: Optional[str] = None,
|
||||
bootstrap_port: Optional[int] = None,
|
||||
bootstrap_room: Optional[int] = None,
|
||||
data_parallel_rank: Optional[int] = None,
|
||||
):
|
||||
# Input and output info
|
||||
self.rid = rid
|
||||
@@ -605,6 +606,9 @@ class Req:
|
||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||
|
||||
# For data parallel rank routing
|
||||
self.data_parallel_rank: Optional[int] = data_parallel_rank
|
||||
|
||||
# the start index of the sent kv cache
|
||||
# We want to send it chunk by chunk for chunked prefill.
|
||||
# After every chunk forward, we do the following:
|
||||
|
||||
@@ -949,6 +949,7 @@ class Scheduler(
|
||||
bootstrap_host=recv_req.bootstrap_host,
|
||||
bootstrap_port=recv_req.bootstrap_port,
|
||||
bootstrap_room=recv_req.bootstrap_room,
|
||||
data_parallel_rank=recv_req.data_parallel_rank,
|
||||
)
|
||||
req.tokenizer = self.tokenizer
|
||||
|
||||
|
||||
@@ -570,6 +570,7 @@ class TokenizerManager:
|
||||
session_params=session_params,
|
||||
custom_logit_processor=obj.custom_logit_processor,
|
||||
return_hidden_states=obj.return_hidden_states,
|
||||
data_parallel_rank=obj.data_parallel_rank,
|
||||
)
|
||||
elif isinstance(obj, EmbeddingReqInput):
|
||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||
|
||||
Reference in New Issue
Block a user