feat: add direct routing strategy to DP worker (#6884)
This commit is contained in:
@@ -109,10 +109,12 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
mgr: BaseKVManager,
|
mgr: BaseKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.bootstrap_addr = bootstrap_addr
|
self.bootstrap_addr = bootstrap_addr
|
||||||
self.kv_mgr = mgr
|
self.kv_mgr = mgr
|
||||||
|
self.data_parallel_rank = data_parallel_rank
|
||||||
|
|
||||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
self.prefill_tp_size, self.prefill_dp_size = (
|
self.prefill_tp_size, self.prefill_dp_size = (
|
||||||
@@ -180,6 +182,10 @@ class CommonKVReceiver(BaseKVReceiver):
|
|||||||
self.target_tp_rank = self.target_tp_ranks[0]
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
self.required_dst_info_num = 1
|
self.required_dst_info_num = 1
|
||||||
|
|
||||||
|
if self.data_parallel_rank is not None:
|
||||||
|
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||||
|
self.target_dp_group = self.data_parallel_rank
|
||||||
|
else:
|
||||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||||
|
|
||||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ class DecodePreallocQueue:
|
|||||||
mgr=self.kv_manager,
|
mgr=self.kv_manager,
|
||||||
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
||||||
bootstrap_room=req.bootstrap_room,
|
bootstrap_room=req.bootstrap_room,
|
||||||
|
data_parallel_rank=req.data_parallel_rank,
|
||||||
)
|
)
|
||||||
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
self.queue.append(DecodeRequest(req=req, kv_receiver=kv_receiver))
|
||||||
|
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class FakeKVReceiver(BaseKVReceiver):
|
|||||||
mgr: BaseKVManager,
|
mgr: BaseKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.has_init = False
|
self.has_init = False
|
||||||
|
|
||||||
|
|||||||
@@ -765,6 +765,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
mgr: MooncakeKVManager,
|
mgr: MooncakeKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.bootstrap_room = bootstrap_room
|
self.bootstrap_room = bootstrap_room
|
||||||
self.bootstrap_addr = bootstrap_addr
|
self.bootstrap_addr = bootstrap_addr
|
||||||
@@ -772,6 +773,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.session_id = self.kv_mgr.get_session_id()
|
self.session_id = self.kv_mgr.get_session_id()
|
||||||
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
||||||
self.conclude_state = None
|
self.conclude_state = None
|
||||||
|
self.data_parallel_rank = data_parallel_rank
|
||||||
|
|
||||||
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||||
self.prefill_tp_size, self.prefill_dp_size = (
|
self.prefill_tp_size, self.prefill_dp_size = (
|
||||||
@@ -845,7 +847,11 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.target_tp_rank = self.target_tp_ranks[0]
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
self.required_dst_info_num = 1
|
self.required_dst_info_num = 1
|
||||||
|
|
||||||
self.target_dp_group = self.bootstrap_room % self.prefill_dp_size
|
if self.data_parallel_rank is not None:
|
||||||
|
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||||
|
self.target_dp_group = self.data_parallel_rank
|
||||||
|
else:
|
||||||
|
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||||
|
|
||||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||||
bootstrap_key = (
|
bootstrap_key = (
|
||||||
|
|||||||
@@ -407,9 +407,10 @@ class NixlKVReceiver(CommonKVReceiver):
|
|||||||
mgr: NixlKVManager,
|
mgr: NixlKVManager,
|
||||||
bootstrap_addr: str,
|
bootstrap_addr: str,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
self.started_transfer = False
|
self.started_transfer = False
|
||||||
super().__init__(mgr, bootstrap_addr, bootstrap_room)
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||||
|
|
||||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||||
for bootstrap_info in self.bootstrap_infos:
|
for bootstrap_info in self.bootstrap_infos:
|
||||||
|
|||||||
@@ -23,6 +23,12 @@ class EngineBase(ABC):
|
|||||||
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
|
||||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
|
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None,
|
||||||
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
custom_logit_processor: Optional[Union[List[str], str]] = None,
|
||||||
|
return_hidden_states: Optional[bool] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||||
|
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||||
|
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> Union[Dict, Iterator[Dict]]:
|
) -> Union[Dict, Iterator[Dict]]:
|
||||||
"""Generate outputs based on given inputs."""
|
"""Generate outputs based on given inputs."""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -167,11 +167,22 @@ class Engine(EngineBase):
|
|||||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> Union[Dict, Iterator[Dict]]:
|
) -> Union[Dict, Iterator[Dict]]:
|
||||||
"""
|
"""
|
||||||
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
Please refer to `GenerateReqInput` for the documentation.
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
|
if self.server_args.enable_dp_attention:
|
||||||
|
if data_parallel_rank is None:
|
||||||
|
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||||
|
elif data_parallel_rank < 0:
|
||||||
|
raise ValueError("data_parallel_rank must be non-negative")
|
||||||
|
elif data_parallel_rank >= self.server_args.dp_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"data_parallel_rank must be less than dp_size: {self.server_args.dp_size}"
|
||||||
|
)
|
||||||
|
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
text=prompt,
|
text=prompt,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -188,6 +199,7 @@ class Engine(EngineBase):
|
|||||||
bootstrap_host=bootstrap_host,
|
bootstrap_host=bootstrap_host,
|
||||||
bootstrap_port=bootstrap_port,
|
bootstrap_port=bootstrap_port,
|
||||||
bootstrap_room=bootstrap_room,
|
bootstrap_room=bootstrap_room,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
@@ -237,11 +249,24 @@ class Engine(EngineBase):
|
|||||||
bootstrap_host: Optional[Union[List[str], str]] = None,
|
bootstrap_host: Optional[Union[List[str], str]] = None,
|
||||||
bootstrap_port: Optional[Union[List[int], int]] = None,
|
bootstrap_port: Optional[Union[List[int], int]] = None,
|
||||||
bootstrap_room: Optional[Union[List[int], int]] = None,
|
bootstrap_room: Optional[Union[List[int], int]] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
) -> Union[Dict, AsyncIterator[Dict]]:
|
) -> Union[Dict, AsyncIterator[Dict]]:
|
||||||
"""
|
"""
|
||||||
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
|
||||||
Please refer to `GenerateReqInput` for the documentation.
|
Please refer to `GenerateReqInput` for the documentation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if self.server_args.enable_dp_attention:
|
||||||
|
if data_parallel_rank is None:
|
||||||
|
logger.info("data_parallel_rank not provided, using default dispatch")
|
||||||
|
elif data_parallel_rank < 0:
|
||||||
|
raise ValueError("data_parallel_rank must be non-negative")
|
||||||
|
elif data_parallel_rank >= self.server_args.dp_size:
|
||||||
|
raise ValueError(
|
||||||
|
f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"data_parallel_rank: {data_parallel_rank}")
|
||||||
obj = GenerateReqInput(
|
obj = GenerateReqInput(
|
||||||
text=prompt,
|
text=prompt,
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
@@ -257,6 +282,7 @@ class Engine(EngineBase):
|
|||||||
bootstrap_host=bootstrap_host,
|
bootstrap_host=bootstrap_host,
|
||||||
bootstrap_port=bootstrap_port,
|
bootstrap_port=bootstrap_port,
|
||||||
bootstrap_room=bootstrap_room,
|
bootstrap_room=bootstrap_room,
|
||||||
|
data_parallel_rank=data_parallel_rank,
|
||||||
)
|
)
|
||||||
generator = self.tokenizer_manager.generate_request(obj, None)
|
generator = self.tokenizer_manager.generate_request(obj, None)
|
||||||
|
|
||||||
|
|||||||
@@ -248,10 +248,18 @@ class DataParallelController:
|
|||||||
|
|
||||||
def round_robin_scheduler(self, req: Req):
|
def round_robin_scheduler(self, req: Req):
|
||||||
if self.server_args.disaggregation_mode == "null":
|
if self.server_args.disaggregation_mode == "null":
|
||||||
|
if req.data_parallel_rank is not None:
|
||||||
|
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||||
|
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||||
|
else:
|
||||||
self.workers[self.round_robin_counter].send_pyobj(req)
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
||||||
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(
|
||||||
self.workers
|
self.workers
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
if req.data_parallel_rank is not None:
|
||||||
|
logger.debug(f"Direct routing to DP rank {req.data_parallel_rank}")
|
||||||
|
self.workers[req.data_parallel_rank].send_pyobj(req)
|
||||||
else:
|
else:
|
||||||
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
|
||||||
|
|
||||||
|
|||||||
@@ -106,6 +106,9 @@ class GenerateReqInput:
|
|||||||
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
|
||||||
bootstrap_room: Optional[Union[List[int], int]] = None
|
bootstrap_room: Optional[Union[List[int], int]] = None
|
||||||
|
|
||||||
|
# For data parallel rank routing
|
||||||
|
data_parallel_rank: Optional[int] = None
|
||||||
|
|
||||||
def contains_mm_input(self) -> bool:
|
def contains_mm_input(self) -> bool:
|
||||||
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
return has_valid_data(self.image_data) or has_valid_data(self.audio_data)
|
||||||
|
|
||||||
@@ -417,6 +420,9 @@ class GenerateReqInput:
|
|||||||
bootstrap_room=(
|
bootstrap_room=(
|
||||||
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
self.bootstrap_room[i] if self.bootstrap_room is not None else None
|
||||||
),
|
),
|
||||||
|
data_parallel_rank=(
|
||||||
|
self.data_parallel_rank if self.data_parallel_rank is not None else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -464,6 +470,9 @@ class TokenizedGenerateReqInput:
|
|||||||
bootstrap_port: Optional[int] = None
|
bootstrap_port: Optional[int] = None
|
||||||
bootstrap_room: Optional[int] = None
|
bootstrap_room: Optional[int] = None
|
||||||
|
|
||||||
|
# For data parallel rank routing
|
||||||
|
data_parallel_rank: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmbeddingReqInput:
|
class EmbeddingReqInput:
|
||||||
|
|||||||
@@ -451,6 +451,7 @@ class Req:
|
|||||||
bootstrap_host: Optional[str] = None,
|
bootstrap_host: Optional[str] = None,
|
||||||
bootstrap_port: Optional[int] = None,
|
bootstrap_port: Optional[int] = None,
|
||||||
bootstrap_room: Optional[int] = None,
|
bootstrap_room: Optional[int] = None,
|
||||||
|
data_parallel_rank: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Input and output info
|
# Input and output info
|
||||||
self.rid = rid
|
self.rid = rid
|
||||||
@@ -605,6 +606,9 @@ class Req:
|
|||||||
self.bootstrap_room: Optional[int] = bootstrap_room
|
self.bootstrap_room: Optional[int] = bootstrap_room
|
||||||
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
self.disagg_kv_sender: Optional[BaseKVSender] = None
|
||||||
|
|
||||||
|
# For data parallel rank routing
|
||||||
|
self.data_parallel_rank: Optional[int] = data_parallel_rank
|
||||||
|
|
||||||
# the start index of the sent kv cache
|
# the start index of the sent kv cache
|
||||||
# We want to send it chunk by chunk for chunked prefill.
|
# We want to send it chunk by chunk for chunked prefill.
|
||||||
# After every chunk forward, we do the following:
|
# After every chunk forward, we do the following:
|
||||||
|
|||||||
@@ -949,6 +949,7 @@ class Scheduler(
|
|||||||
bootstrap_host=recv_req.bootstrap_host,
|
bootstrap_host=recv_req.bootstrap_host,
|
||||||
bootstrap_port=recv_req.bootstrap_port,
|
bootstrap_port=recv_req.bootstrap_port,
|
||||||
bootstrap_room=recv_req.bootstrap_room,
|
bootstrap_room=recv_req.bootstrap_room,
|
||||||
|
data_parallel_rank=recv_req.data_parallel_rank,
|
||||||
)
|
)
|
||||||
req.tokenizer = self.tokenizer
|
req.tokenizer = self.tokenizer
|
||||||
|
|
||||||
|
|||||||
@@ -570,6 +570,7 @@ class TokenizerManager:
|
|||||||
session_params=session_params,
|
session_params=session_params,
|
||||||
custom_logit_processor=obj.custom_logit_processor,
|
custom_logit_processor=obj.custom_logit_processor,
|
||||||
return_hidden_states=obj.return_hidden_states,
|
return_hidden_states=obj.return_hidden_states,
|
||||||
|
data_parallel_rank=obj.data_parallel_rank,
|
||||||
)
|
)
|
||||||
elif isinstance(obj, EmbeddingReqInput):
|
elif isinstance(obj, EmbeddingReqInput):
|
||||||
tokenized_obj = TokenizedEmbeddingReqInput(
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
||||||
|
|||||||
Reference in New Issue
Block a user