[PD] Support logprob & Add failure test (#6558)
@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     prepare_abort,
@@ -569,20 +570,13 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
             )
@@ -597,8 +591,7 @@ class Scheduler(
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 transfer_queue=self.disagg_decode_transfer_queue,
                 tree_cache=self.tree_cache,
@@ -618,14 +611,7 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -635,8 +621,7 @@ class Scheduler(
                     else self.draft_worker.model_runner.token_to_kv_pool
                 ),
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
-                aux_dtype=aux_dtype,
+                metadata_buffers=self.disagg_metadata_buffers,
                 tp_rank=self.tp_rank,
                 tp_size=self.tp_size,
                 bootstrap_port=self.server_args.disaggregation_bootstrap_port,
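
The change collapses the duplicated buffer setup in the decode and prefill branches (an int32 output_id_buffer padded to 16 elements per request, since 16 * 4 bytes meets the 64-byte RDMA minimum noted in the removed comment) into a single MetadataBuffers container, which also gives the logprob metadata from the PR title one place to live. Below is a minimal sketch of what such a container could look like; the real class lives in sglang.srt.disaggregation.utils, and the field names, logprob dtypes, and get_buf_infos helper here are illustrative assumptions that this diff does not confirm.

import torch


class MetadataBuffersSketch:
    """Illustrative stand-in for sglang.srt.disaggregation.utils.MetadataBuffers.

    Groups the per-request transfer metadata that was previously managed as a
    bare list of tensors. Each buffer is a CPU tensor of shape (size, n),
    where size is the max number of running requests and n is padded so that
    n * dtype.itemsize >= 64 bytes, as the removed comment says RDMA requires.
    """

    def __init__(self, size: int):
        # (size, 16) int32 -> 16 * 4 = 64 bytes per request, the RDMA minimum.
        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
        # Hypothetical logprob buffers motivated by the PR title; the actual
        # fields and dtypes in MetadataBuffers may differ.
        self.output_token_logprobs_val = torch.zeros(
            (size, 16), dtype=torch.float32, device="cpu"
        )
        self.output_token_logprobs_idx = torch.zeros(
            (size, 16), dtype=torch.int32, device="cpu"
        )

    def get_buf_infos(self):
        # Hypothetical helper: expose pointers and lengths so a transfer
        # backend can register every buffer once, instead of each call site
        # handling an ad-hoc list of tensors.
        bufs = [
            self.output_ids,
            self.output_token_logprobs_val,
            self.output_token_logprobs_idx,
        ]
        ptrs = [t.data_ptr() for t in bufs]
        total_lens = [t.nbytes for t in bufs]
        item_lens = [t[0].nbytes for t in bufs]
        return ptrs, total_lens, item_lens

Centralizing the tensors in one object lets DecodeTransferQueue and PrefillBootstrapQueue share the same handle (self.disagg_metadata_buffers), and the separate aux_dtype argument can be dropped because each buffer now carries its own dtype.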