[PD] use int32 for kv indices & get num_reserved_decode_tokens from server_args (#7214)
This commit is contained in:
@@ -70,7 +70,7 @@ class BaseKVSender(ABC):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def send(self, kv_indices: npt.NDArray[np.int64]):
|
||||
def send(self, kv_indices: npt.NDArray[np.int32]):
|
||||
"""
|
||||
Send the kv cache at the given kv indices to the decoder server
|
||||
"""
|
||||
@@ -102,7 +102,7 @@ class BaseKVReceiver(ABC):
|
||||
): ...
|
||||
|
||||
@abstractmethod
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
"""
|
||||
Notify the prefill server about the kv indices and aux index
|
||||
"""
|
||||
|
||||
@@ -26,8 +26,8 @@ class FastQueue:
|
||||
|
||||
|
||||
def group_concurrent_contiguous(
|
||||
src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
|
||||
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
|
||||
src_indices: npt.NDArray[np.int32], dst_indices: npt.NDArray[np.int32]
|
||||
) -> Tuple[List[npt.NDArray[np.int32]], List[npt.NDArray[np.int32]]]:
|
||||
"""Vectorised NumPy implementation."""
|
||||
if src_indices.size == 0:
|
||||
return [], []
|
||||
|
||||
@@ -158,6 +158,7 @@ class DecodePreallocQueue:
|
||||
bootstrap_port: int,
|
||||
max_total_num_tokens: int,
|
||||
prefill_pp_size: int,
|
||||
num_reserved_decode_tokens: int,
|
||||
transfer_backend: TransferBackend,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
@@ -178,9 +179,7 @@ class DecodePreallocQueue:
|
||||
self.bootstrap_port = bootstrap_port
|
||||
self.max_total_num_tokens = max_total_num_tokens
|
||||
self.prefill_pp_size = prefill_pp_size
|
||||
self.num_reserved_decode_tokens = int(
|
||||
os.environ.get("SGLANG_NUM_RESERVED_DECODE_TOKENS", "512")
|
||||
)
|
||||
self.num_reserved_decode_tokens = num_reserved_decode_tokens
|
||||
self.transfer_backend = transfer_backend
|
||||
# Queue for requests pending pre-allocation
|
||||
self.queue: List[DecodeRequest] = []
|
||||
@@ -404,7 +403,6 @@ class DecodePreallocQueue:
|
||||
]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
|
||||
decode_req.metadata_buffer_index = (
|
||||
|
||||
@@ -48,7 +48,7 @@ class FakeKVSender(BaseKVSender):
|
||||
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
):
|
||||
self.has_sent = True
|
||||
logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
|
||||
|
||||
@@ -59,7 +59,7 @@ class KVTransferError(Exception):
|
||||
@dataclasses.dataclass
|
||||
class TransferKVChunk:
|
||||
room: int
|
||||
prefill_kv_indices: npt.NDArray[np.int64]
|
||||
prefill_kv_indices: npt.NDArray[np.int32]
|
||||
index_slice: slice
|
||||
is_last: bool
|
||||
prefill_aux_index: Optional[int]
|
||||
@@ -72,7 +72,7 @@ class TransferInfo:
|
||||
endpoint: str
|
||||
dst_port: int
|
||||
mooncake_session_id: str
|
||||
dst_kv_indices: npt.NDArray[np.int64]
|
||||
dst_kv_indices: npt.NDArray[np.int32]
|
||||
dst_aux_index: int
|
||||
required_dst_info_num: int
|
||||
is_dummy: bool
|
||||
@@ -81,10 +81,10 @@ class TransferInfo:
|
||||
def from_zmq(cls, msg: List[bytes]):
|
||||
if msg[4] == b"" and msg[5] == b"":
|
||||
is_dummy = True
|
||||
dst_kv_indices = np.array([], dtype=np.int64)
|
||||
dst_kv_indices = np.array([], dtype=np.int32)
|
||||
dst_aux_index = None
|
||||
else:
|
||||
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
|
||||
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
||||
dst_aux_index = int(msg[5].decode("ascii"))
|
||||
is_dummy = False
|
||||
return cls(
|
||||
@@ -233,9 +233,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
def send_kvcache(
|
||||
self,
|
||||
mooncake_session_id: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int64],
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int64],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
executor: concurrent.futures.ThreadPoolExecutor,
|
||||
):
|
||||
# Group by indices
|
||||
@@ -545,7 +545,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
def add_transfer_request(
|
||||
self,
|
||||
bootstrap_room: int,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
index_slice: slice,
|
||||
is_last: bool,
|
||||
aux_index: Optional[int] = None,
|
||||
@@ -701,7 +701,7 @@ class MooncakeKVSender(BaseKVSender):
|
||||
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
):
|
||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||
self.curr_idx += len(kv_indices)
|
||||
@@ -971,7 +971,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
cls._socket_locks[endpoint] = threading.Lock()
|
||||
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
|
||||
@@ -44,7 +44,7 @@ class TransferInfo:
|
||||
agent_metadata: bytes
|
||||
agent_name: str
|
||||
dst_kv_ptrs: list[int]
|
||||
dst_kv_indices: npt.NDArray[np.int64]
|
||||
dst_kv_indices: npt.NDArray[np.int32]
|
||||
dst_aux_ptrs: list[int]
|
||||
dst_aux_index: int
|
||||
dst_gpu_id: int
|
||||
@@ -62,7 +62,7 @@ class TransferInfo:
|
||||
agent_metadata=msg[3],
|
||||
agent_name=msg[4].decode("ascii"),
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
|
||||
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
||||
dst_aux_index=int(msg[8].decode("ascii")),
|
||||
dst_gpu_id=int(msg[9].decode("ascii")),
|
||||
@@ -162,9 +162,9 @@ class NixlKVManager(CommonKVManager):
|
||||
def send_kvcache(
|
||||
self,
|
||||
peer_name: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int64],
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int64],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
dst_gpu_id: int,
|
||||
notif: str,
|
||||
):
|
||||
@@ -246,7 +246,7 @@ class NixlKVManager(CommonKVManager):
|
||||
def add_transfer_request(
|
||||
self,
|
||||
bootstrap_room: int,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
index_slice: slice,
|
||||
is_last: bool,
|
||||
chunk_id: int,
|
||||
@@ -373,7 +373,7 @@ class NixlKVSender(BaseKVSender):
|
||||
|
||||
def send(
|
||||
self,
|
||||
kv_indices: npt.NDArray[np.int64],
|
||||
kv_indices: npt.NDArray[np.int32],
|
||||
):
|
||||
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
||||
self.curr_idx += len(kv_indices)
|
||||
@@ -417,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
self.started_transfer = False
|
||||
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
||||
|
||||
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
||||
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
||||
for bootstrap_info in self.bootstrap_infos:
|
||||
self.prefill_server_url = (
|
||||
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
||||
|
||||
@@ -576,7 +576,6 @@ class SchedulerDisaggregationPrefillMixin:
|
||||
self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
|
||||
.cpu()
|
||||
.numpy()
|
||||
.astype(np.int64)
|
||||
)
|
||||
req.start_send_idx = end_idx
|
||||
if last_chunk:
|
||||
|
||||
@@ -656,6 +656,7 @@ class Scheduler(
|
||||
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
|
||||
max_total_num_tokens=self.max_total_num_tokens,
|
||||
prefill_pp_size=self.server_args.disaggregation_prefill_pp,
|
||||
num_reserved_decode_tokens=self.server_args.num_reserved_decode_tokens,
|
||||
transfer_backend=self.transfer_backend,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user