[PD] Support logprob & Add failure test (#6558)
@@ -36,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -78,8 +79,7 @@ class DecodePreallocQueue:
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
        scheduler: Scheduler,
         transfer_queue: DecodeTransferQueue,
         tree_cache: BasePrefixCache,
@@ -94,7 +94,6 @@ class DecodePreallocQueue:
         self.token_to_kv_pool = token_to_kv_pool_allocator.get_kvcache()
         self.draft_token_to_kv_pool = draft_token_to_kv_pool
         self.is_mla_backend = is_mla_backend(self.token_to_kv_pool)
-        self.aux_dtype = aux_dtype
         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
         self.scheduler = scheduler
@@ -133,15 +132,9 @@ class DecodePreallocQueue:
         kv_args.kv_data_lens = kv_data_lens
         kv_args.kv_item_lens = kv_item_lens

-        kv_args.aux_data_ptrs = [
-            output_id_tensor.data_ptr() for output_id_tensor in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -211,7 +204,18 @@ class DecodePreallocQueue:
         indices_to_remove = set()
         allocatable_tokens = self._allocatable_tokens()

+        # First, remove all failed requests from the queue
+        for i, decode_req in enumerate(self.queue):
+            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
+                self.scheduler.stream_output(
+                    [decode_req.req], decode_req.req.return_logprob
+                )
+                indices_to_remove.add(i)
+
         for i, decode_req in enumerate(self.queue):
+            if i in indices_to_remove:
+                continue
+
             if not decode_req.waiting_for_input:
                 continue

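Note on the hunk above: aborted requests are streamed out to the client immediately and then skipped via `indices_to_remove` when the queue is walked. A minimal standalone sketch of this filter-then-skip idiom (the queue items below are stand-ins, not real DecodeRequest objects):

    # Sketch only: flush aborted items first, then skip them on the main pass.
    queue = ["ok-1", "ABORTED", "ok-2"]
    indices_to_remove = {i for i, item in enumerate(queue) if item == "ABORTED"}
    # ...aborted items would be streamed back to the client here...
    survivors = [item for i, item in enumerate(queue) if i not in indices_to_remove]
    assert survivors == ["ok-1", "ok-2"]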
@@ -331,7 +335,7 @@ class DecodeTransferQueue:
         self,
         gloo_group: ProcessGroup,
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: torch.Tensor,
+        metadata_buffers: MetadataBuffers,
         scheduler: Scheduler,
         tree_cache: BasePrefixCache,
     ):
@@ -342,11 +346,11 @@ class DecodeTransferQueue:
         self.scheduler = scheduler
         self.tree_cache = tree_cache

-    def add(self, req_conn: DecodeRequest) -> None:
-        self.queue.append(req_conn)
+    def add(self, decode_req: DecodeRequest) -> None:
+        self.queue.append(decode_req)

-    def extend(self, req_conns) -> None:
-        self.queue.extend(req_conns)
+    def extend(self, decode_reqs: List[DecodeRequest]) -> None:
+        self.queue.extend(decode_reqs)

     def pop_transferred(self) -> List[DecodeRequest]:
         if not self.queue:
@@ -356,14 +360,6 @@ class DecodeTransferQueue:
             [decode_req.kv_receiver for decode_req in self.queue], self.gloo_group
         )

-        # First, remove all failed requests from the queue
-        for i, decode_req in enumerate(self.queue):
-            if isinstance(decode_req.req.finished_reason, FINISH_ABORT):
-                self.scheduler.stream_output(
-                    [decode_req.req], decode_req.req.return_logprob
-                )
-                indices_to_remove.add(i)
-
         transferred_reqs = []
         indices_to_remove = set()
         for i, (decode_req, poll) in enumerate(zip(self.queue, polls)):
@@ -387,16 +383,37 @@ class DecodeTransferQueue:
                 indices_to_remove.add(i)
                 continue
             elif poll == KVPoll.Success:
-                # pop and push it to waiting queue
                 idx = decode_req.metadata_buffer_index
                 assert len(decode_req.req.output_ids) == 0
-                output_id_buffer = self.metadata_buffers[0]
-                # the last dimension is padded by the same values.
-                output_id = output_id_buffer[idx][0].item()
-                assert decode_req.req.transferred_output_id is None
-                decode_req.req.transferred_output_id = output_id
-                transferred_reqs.append(decode_req)
+                (
+                    output_id,
+                    output_token_logprobs_val,
+                    output_token_logprobs_idx,
+                    output_top_logprobs_val,
+                    output_top_logprobs_idx,
+                ) = self.metadata_buffers.get_buf(idx)
+
+                decode_req.req.output_ids.append(output_id[0].item())
+
+                if decode_req.req.return_logprob:
+                    decode_req.req.output_token_logprobs_val.append(
+                        output_token_logprobs_val[0].item()
+                    )
+                    decode_req.req.output_token_logprobs_idx.append(
+                        output_token_logprobs_idx[0].item()
+                    )
+                    decode_req.req.output_top_logprobs_val.append(
+                        output_top_logprobs_val[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+                    decode_req.req.output_top_logprobs_idx.append(
+                        output_top_logprobs_idx[
+                            : decode_req.req.top_logprobs_num
+                        ].tolist()
+                    )
+
+                transferred_reqs.append(decode_req.req)
                 indices_to_remove.add(i)
             elif poll in [
                 KVPoll.Bootstrapping,
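The rows returned by get_buf are fixed-width and padded, so the decode side takes element [0] for scalar fields and slices the first top_logprobs_num entries of the top-logprob rows. Illustrative only (widths and values assumed):

    import torch

    row = torch.zeros(128)                       # one output_top_logprobs_val row
    row[:3] = torch.tensor([-0.1, -1.2, -2.3])   # what the prefill side wrote
    top_logprobs_num = 3                         # what this request asked for
    vals = row[:top_logprobs_num].tolist()       # keeps only the requested entries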
@@ -451,7 +468,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     self._prepare_idle_batch_and_run(None)
             else:
@@ -497,7 +516,9 @@ class SchedulerDisaggregationDecodeMixin:
             # Generate fake extend output.
             if batch.forward_mode.is_extend():
                 # Note: Logprobs should be handled on the prefill engine.
-                self.stream_output(batch.reqs, False)
+                self.stream_output(
+                    batch.reqs, any(req.return_logprob for req in batch.reqs)
+                )
                 if prepare_dp_attn_flag:
                     batch_, result = self._prepare_idle_batch_and_run(
                         None, delay_process=True
@@ -618,15 +639,8 @@ class SchedulerDisaggregationDecodeMixin:

     def process_decode_queue(self: Scheduler):
         req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
-
-        def _num_pre_alloc(req):
-            return len(req.req.origin_input_ids) + max(len(req.req.output_ids) - 1, 0)
-
-        self.num_tokens_pre_allocated += sum(_num_pre_alloc(req) for req in req_conns)
         self.disagg_decode_transfer_queue.extend(req_conns)
         alloc_reqs = (
             self.disagg_decode_transfer_queue.pop_transferred()
         )  # the requests which kv has arrived
-        self.num_tokens_pre_allocated -= sum(_num_pre_alloc(req) for req in alloc_reqs)
-
-        self.waiting_queue.extend([req.req for req in alloc_reqs])
+        self.waiting_queue.extend(alloc_reqs)
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING

 import torch

-from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

 logger = logging.getLogger(__name__)
@@ -76,6 +76,11 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)
+
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
         self.extend_num_tokens = extend_num_tokens
         self.prefix_lens = [len(r.prefix_indices) for r in reqs]
         self.extend_lens = [r.extend_input_len for r in reqs]
@@ -94,12 +99,41 @@ class ScheduleBatchDisaggregationDecodeMixin:
         """Assign the buffered last input id to schedule batch"""
         self.output_ids = []
         for req in self.reqs:
-            if req.output_ids and len(req.output_ids) > 0:
-                # resumed retracted req
-                self.output_ids.append(req.output_ids[-1])
-            else:
-                assert req.transferred_output_id is not None
-                req.output_ids.append(req.transferred_output_id)
-                self.output_ids.append(req.transferred_output_id)
+            self.output_ids.append(req.output_ids[-1])
             self.tree_cache.cache_unfinished_req(req)
         self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+        # Simulate the eagle run. We add mock data to hidden states for the
+        # ease of implementation now meaning the first token will have acc rate
+        # of 0.
+        if not self.spec_algorithm.is_none():
+
+            b = len(self.reqs)
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk, device=self.device
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            # local import to avoid circular import
+            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+            spec_info = EagleDraftInput(
+                topk_p=topk_p,
+                topk_index=topk_index,
+                hidden_states=torch.ones(
+                    (b, model_config.hidden_size), device=self.device
+                ),
+                verified_id=self.output_ids,
+            )
+            spec_info.prepare_for_extend(self)
+            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+            self.spec_info = spec_info
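For intuition about the mocked draft input above: the arange construction just needs well-formed, strictly decreasing pseudo-probabilities per request. With b = 2 requests and speculative_eagle_topk = 4 (example values only):

    import torch

    b, topk = 2, 4
    topk_p = torch.arange(b * topk, 0, -1, dtype=torch.float32).reshape(b, topk)
    topk_p /= b * topk
    # [[1.000, 0.875, 0.750, 0.625],
    #  [0.500, 0.375, 0.250, 0.125]]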
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
                 session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]

             # Wait for both responses to complete. Prefill should end first.
-            _, decode_response = await asyncio.gather(*tasks)
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if "return_logprob" in modified_request:
+
+                prefill_json = await prefill_response.json()
+                ret_json = await decode_response.json()
+
+                # merge `meta_info.input_token_logprobs` from prefill to decode
+                if "meta_info" in ret_json:
+                    if "input_token_logprobs" in ret_json["meta_info"]:
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            prefill_json["meta_info"]["input_token_logprobs"]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+            else:
+                ret_json = await decode_response.json()

             return ORJSONResponse(
-                content=await decode_response.json(),
+                content=ret_json,
                 status_code=decode_response.status,
             )
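What the merge above does to the payload, shown with made-up logprob entries: the prefill engine's prompt logprobs are prepended so the client sees one contiguous list.

    prefill = {"meta_info": {"input_token_logprobs": [[-1.2, 791], [-0.8, 6864]]}}
    decode = {"meta_info": {"input_token_logprobs": [[-0.5, 374]]}}

    decode["meta_info"]["input_token_logprobs"] = (
        prefill["meta_info"]["input_token_logprobs"]
        + decode["meta_info"]["input_token_logprobs"]
    )
    # decode["meta_info"]["input_token_logprobs"] now holds all three entries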
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
                 total=3600
             )  # Add timeout for request reliability
         ) as session:
-            try:
-                # Create the tasks for both prefill and decode requests
-                tasks = [
-                    session.post(
-                        f"{prefill_server}/{endpoint}", json=modified_request
-                    ),
-                    session.post(
-                        f"{decode_server}/{endpoint}", json=modified_request
-                    ),
-                ]
-                # Wait for both responses to complete. Since this is streaming, they return immediately.
-                prefill_response, decode_response = await asyncio.gather(*tasks)
+            # Create the tasks for both prefill and decode requests
+            tasks = [
+                session.post(f"{prefill_server}/generate", json=modified_request),
+                session.post(f"{decode_server}/generate", json=modified_request),
+            ]
+            # Wait for both responses to complete. Since this is streaming, they return immediately.
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if modified_request.get("return_logprob", False):
+                prefill_chunks = []
+                async for chunk in prefill_response.content:
+                    prefill_chunks.append(chunk)
+
+                first_prefill_chunk = (
+                    prefill_chunks[0].decode("utf-8")[5:].strip("\n")
+                )
+                first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
+
+                async for chunk in decode_response.content:
+                    # Note: This is inefficient
+                    # merge prefill input_token_logprobs, output_token_logprobs to decode
+                    decoded_chunk = chunk.decode("utf-8")
+                    if (
+                        decoded_chunk
+                        and decoded_chunk.startswith("data:")
+                        and "[DONE]" not in decoded_chunk
+                    ):
+                        ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            first_prefill_chunk_json["meta_info"][
+                                "input_token_logprobs"
+                            ]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+
+                        yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
+                    else:
+                        yield chunk
+            else:
                 async for chunk in decode_response.content:
                     yield chunk
-            except Exception as e:
-                error_msg = {
-                    "error": {"message": f"Stream processing error: {str(e)}"}
-                }
-                yield b"data: " + orjson.dumps(
-                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            finally:
-                if prefill_response is not None:
-                    await prefill_response.release()

         return StreamingResponse(
             stream_results(),
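The slicing above relies on SSE framing: each event is b"data: " + JSON + b"\n\n", so dropping the 5-byte "data:" prefix and stripping newlines recovers the JSON payload. An illustrative chunk (contents assumed):

    import orjson

    chunk = b'data: {"text": "x", "meta_info": {"input_token_logprobs": []}}\n\n'
    payload = chunk.decode("utf-8")[5:].strip("\n")
    ret_json = orjson.loads(payload)  # back to a plain dict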
@@ -32,6 +32,7 @@ from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
     FakeBootstrapHost,
     KVClassType,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     get_kv_class,
@@ -63,8 +64,7 @@ class PrefillBootstrapQueue:
         token_to_kv_pool: KVCache,
         draft_token_to_kv_pool: Optional[KVCache],
         req_to_metadata_buffer_idx_allocator: ReqToMetadataIdxAllocator,
-        metadata_buffers: List[torch.Tensor],
-        aux_dtype: torch.dtype,
+        metadata_buffers: MetadataBuffers,
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
@@ -76,7 +76,6 @@ class PrefillBootstrapQueue:
         self.draft_token_to_kv_pool = draft_token_to_kv_pool

         self.is_mla_backend = is_mla_backend(token_to_kv_pool)
-        self.aux_dtype = aux_dtype

         self.metadata_buffers = metadata_buffers
         self.req_to_metadata_buffer_idx_allocator = req_to_metadata_buffer_idx_allocator
@@ -116,15 +115,9 @@ class PrefillBootstrapQueue:
         kv_args.kv_item_lens = kv_item_lens

-        # Define req -> input ids buffer
-        kv_args.aux_data_ptrs = [
-            metadata_buffer.data_ptr() for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_data_lens = [
-            metadata_buffer.nbytes for metadata_buffer in self.metadata_buffers
-        ]
-        kv_args.aux_item_lens = [
-            metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
-        ]
+        kv_args.aux_data_ptrs, kv_args.aux_data_lens, kv_args.aux_item_lens = (
+            self.metadata_buffers.get_buf_infos()
+        )
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id
         kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
@@ -299,10 +292,9 @@ class SchedulerDisaggregationPrefillMixin:
         launch_done: Optional[threading.Event] = None,
     ) -> None:
         """
-        Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
+        Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         Adapted from process_batch_result_prefill
         """
-
         (
             logits_output,
             next_token_ids,
@@ -315,27 +307,78 @@ class SchedulerDisaggregationPrefillMixin:
             result.extend_logprob_start_len_per_req,
         )

+        logprob_pt = 0
+        # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
         if self.enable_overlap:
             # wait
-            _, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(launch_done)
+            logits_output, next_token_ids, _ = self.tp_worker.resolve_last_batch_result(
+                launch_done
+            )
         else:
             next_token_ids = result.next_token_ids.tolist()

-        for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
+        if batch.return_logprob:
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+            if logits_output.input_token_logprobs is not None:
+                logits_output.input_token_logprobs = tuple(
+                    logits_output.input_token_logprobs.tolist()
+                )
+
+        for i, (req, next_token_id) in enumerate(
+            zip(batch.reqs, next_token_ids, strict=True)
+        ):
             req: Req
             if req.is_chunked <= 0:
                 # There is no output_ids for prefill
                 req.output_ids.append(next_token_id)
                 self.tree_cache.cache_unfinished_req(req)  # update the tree and lock
-                self.send_kv_chunk(req, token_id=next_token_id)
                 self.disagg_prefill_inflight_queue.append(req)
+                if req.return_logprob:
+                    assert extend_logprob_start_len_per_req is not None
+                    assert extend_input_len_per_req is not None
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    num_input_logprobs = extend_input_len - extend_logprob_start_len
+                    self.add_logprob_return_values(
+                        i,
+                        req,
+                        logprob_pt,
+                        next_token_ids,
+                        num_input_logprobs,
+                        logits_output,
+                    )
+                    logprob_pt += num_input_logprobs
+                self.send_kv_chunk(req, last_chunk=True)

                 if req.grammar is not None:
                     req.grammar.accept_token(next_token_id)
                     req.grammar.finished = req.finished()
             else:
                 # being chunked reqs' prefill is not finished
                 req.is_chunked -= 1

+                if req.return_logprob:
+                    extend_logprob_start_len = extend_logprob_start_len_per_req[i]
+                    extend_input_len = extend_input_len_per_req[i]
+                    if extend_logprob_start_len < extend_input_len:
+                        # Update input logprobs.
+                        num_input_logprobs = extend_input_len - extend_logprob_start_len
+                        self.add_input_logprob_return_values(
+                            i,
+                            req,
+                            logits_output,
+                            logprob_pt,
+                            num_input_logprobs,
+                            last_prefill_chunk=False,
+                        )
+                        logprob_pt += num_input_logprobs
+
                 if self.enable_overlap:
-                    self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
+                    self.send_kv_chunk(req, last_chunk=False, end_idx=req.tmp_end_idx)

         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)

     def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
         """
@@ -379,7 +422,11 @@ class SchedulerDisaggregationPrefillMixin:
             )

         # Stream requests which have finished transfer
-        self.stream_output(done_reqs, False, None)
+        self.stream_output(
+            done_reqs,
+            any(req.return_logprob for req in done_reqs),
+            None,
+        )

         self.disagg_prefill_inflight_queue = undone_reqs
@@ -405,7 +452,7 @@ class SchedulerDisaggregationPrefillMixin:
     def send_kv_chunk(
         self: Scheduler,
         req: Req,
-        token_id: Optional[int] = None,
+        last_chunk: bool = False,
         end_idx: Optional[int] = None,
     ) -> None:
         """
@@ -413,37 +460,28 @@ class SchedulerDisaggregationPrefillMixin:
         """
         page_size = self.token_to_kv_pool_allocator.page_size
         start_idx = req.start_send_idx
-        # if end_idx is specified, use it as the end index of the kv chunk because in overlap schedule,
-        # the resolved length is not the same as fill_ids's length
         end_idx = (
             end_idx
             if end_idx is not None
             else min(len(req.fill_ids), len(req.origin_input_ids))
         )

-        last_chunk = token_id is not None
-
         if not last_chunk:
             # if not the last chunk and the last page is partial, delay the last partial page to the next send
             end_idx = end_idx - end_idx % page_size

-        # Update next start_send_idx
-        req.start_send_idx = end_idx
-
         kv_indices = (
             self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
             .cpu()
             .numpy()
         )
-        if last_chunk is True:
-            self.disagg_prefill_bootstrap_queue.store_prefill_results(
-                req.metadata_buffer_index, token_id
-            )
+        req.start_send_idx = end_idx
+        if last_chunk:
+            self.disagg_metadata_buffers.set_buf(req)
         page_indices = kv_to_page_indices(kv_indices, page_size)
         if len(page_indices) == 0:
             logger.info(
                 f"Skip sending kv chunk for request {req.rid=} {req.bootstrap_room=} because page_indices is empty"
             )
             return

         req.disagg_kv_sender.send(page_indices)
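A small numeric illustration of the page alignment above (the page_size value is arbitrary): a non-final chunk is cut back to a page boundary so the partially filled last page is sent with the next chunk rather than transferred twice.

    page_size = 16
    end_idx = 53
    end_idx = end_idx - end_idx % page_size
    # end_idx == 48; tokens 48..52 are deferred to the next send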
@@ -6,7 +6,7 @@ import random
 import warnings
 from collections import deque
 from enum import Enum
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional

 import numpy as np
 import requests
@@ -15,6 +15,9 @@ import torch.distributed as dist

 from sglang.srt.utils import get_ip

+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+
 FakeBootstrapHost = "2.2.2.2"

 # env var for testing failure, convert to float explicitly
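The comment above refers to the failure-injection knob used by the new tests. A sketch of the read pattern it describes (the exact parsing in utils.py is not shown in this diff, so treat this as an assumption):

    import os

    # "convert to float explicitly": env vars arrive as strings
    FAILURE_PROB = float(os.getenv("DISAGGREGATION_TEST_FAILURE_PROB", "0.0"))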
@@ -196,3 +199,83 @@ def prepare_abort(req: Req, error_message: str, status_code=None):
     req.input_top_logprobs_idx = []
     req.input_token_ids_logprobs_val = []
     req.input_token_ids_logprobs_idx = []
+
+
+class MetadataBuffers:
+    def __init__(self, size: int, max_top_logprobs_num: int = 128):
+        # TODO: abort top_logprobs_num > 128 in PD
+
+        # We transfer the metadata of first output token to decode
+        # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
+        self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device="cpu")
+        self.output_token_logprobs_val = torch.zeros(
+            (size, 16), dtype=torch.float32, device="cpu"
+        )
+        self.output_token_logprobs_idx = torch.zeros(
+            (size, 16), dtype=torch.int32, device="cpu"
+        )
+        self.output_top_logprobs_val = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.float32, device="cpu"
+        )
+        self.output_top_logprobs_idx = torch.zeros(
+            (size, max_top_logprobs_num), dtype=torch.int32, device="cpu"
+        )
+
+    def get_buf_infos(self):
+        ptrs = [
+            self.output_ids.data_ptr(),
+            self.output_token_logprobs_val.data_ptr(),
+            self.output_token_logprobs_idx.data_ptr(),
+            self.output_top_logprobs_val.data_ptr(),
+            self.output_top_logprobs_idx.data_ptr(),
+        ]
+        data_lens = [
+            self.output_ids.nbytes,
+            self.output_token_logprobs_val.nbytes,
+            self.output_token_logprobs_idx.nbytes,
+            self.output_top_logprobs_val.nbytes,
+            self.output_top_logprobs_idx.nbytes,
+        ]
+        item_lens = [
+            self.output_ids[0].nbytes,
+            self.output_token_logprobs_val[0].nbytes,
+            self.output_token_logprobs_idx[0].nbytes,
+            self.output_top_logprobs_val[0].nbytes,
+            self.output_top_logprobs_idx[0].nbytes,
+        ]
+        return ptrs, data_lens, item_lens
+
+    def get_buf(self, idx: int):
+        return (
+            self.output_ids[idx],
+            self.output_token_logprobs_val[idx],
+            self.output_token_logprobs_idx[idx],
+            self.output_top_logprobs_val[idx],
+            self.output_top_logprobs_idx[idx],
+        )
+
+    def set_buf(self, req: Req):
+
+        self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
+        if req.return_logprob:
+            if req.output_token_logprobs_val:  # not none or empty list
+                self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_val[0]
+                )
+            if req.output_token_logprobs_idx:  # not none or empty list
+                self.output_token_logprobs_idx[req.metadata_buffer_index][0] = (
+                    req.output_token_logprobs_idx[0]
+                )
+
+            if req.output_top_logprobs_val:  # not none or empty list
+                self.output_top_logprobs_val[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_val[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_val[0], dtype=torch.float32, device="cpu"
+                )
+            if req.output_top_logprobs_idx:  # not none or empty list
+                self.output_top_logprobs_idx[req.metadata_buffer_index][
+                    : len(req.output_top_logprobs_idx[0])
+                ] = torch.tensor(
+                    req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
+                )
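A hypothetical round trip through the class above (assuming MetadataBuffers from this hunk is importable; the Req argument is faked with a namespace purely for illustration, the real Req lives in schedule_batch.py):

    from types import SimpleNamespace

    bufs = MetadataBuffers(size=4)
    req = SimpleNamespace(
        metadata_buffer_index=1,
        output_ids=[42],
        return_logprob=True,
        output_token_logprobs_val=[-0.25],
        output_token_logprobs_idx=[42],
        output_top_logprobs_val=[[-0.25, -1.5]],
        output_top_logprobs_idx=[[42, 7]],
    )
    bufs.set_buf(req)  # what the prefill side does on the last chunk
    out_ids, lp_val, lp_idx, top_val, top_idx = bufs.get_buf(1)  # decode side
    assert out_ids[0].item() == 42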
@@ -607,9 +607,6 @@ class Req:
         self.tmp_end_idx: int = -1
         self.metadata_buffer_index: int = -1

-        # The first output_id transferred from prefill instance.
-        self.transferred_output_id: Optional[int] = None
-
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -48,6 +48,7 @@ from sglang.srt.disaggregation.prefill import (
 )
 from sglang.srt.disaggregation.utils import (
     DisaggregationMode,
+    MetadataBuffers,
     ReqToMetadataIdxAllocator,
     TransferBackend,
     prepare_abort,
@@ -569,20 +570,13 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             # The decode requests polling kv cache
             self.disagg_decode_transfer_queue = DecodeTransferQueue(
                 gloo_group=self.attn_tp_cpu_group,
                 req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-                metadata_buffers=metadata_buffers,
+                metadata_buffers=self.disagg_metadata_buffers,
                 scheduler=self,
                 tree_cache=self.tree_cache,
             )
@@ -597,8 +591,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
            scheduler=self,
             transfer_queue=self.disagg_decode_transfer_queue,
             tree_cache=self.tree_cache,
@@ -618,14 +611,7 @@ class Scheduler(
             req_to_metadata_buffer_idx_allocator = ReqToMetadataIdxAllocator(
                 buffer_size
             )
-            aux_dtype = torch.int32
-            # A list of metadata buffers. The shape is (b, metadata_size) where
-            # b corresponds to a max running requests. The last shape * dtype.itemsize
-            # should be larger than 64 bytes to work with RDMA, so we pad it.
-            output_id_buffer = torch.zeros(
-                (buffer_size, 16), dtype=aux_dtype, device="cpu"
-            )
-            metadata_buffers = [output_id_buffer]
+            self.disagg_metadata_buffers = MetadataBuffers(buffer_size)

             self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                 token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
@@ -635,8 +621,7 @@ class Scheduler(
                 else self.draft_worker.model_runner.token_to_kv_pool
             ),
             req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
-            metadata_buffers=metadata_buffers,
-            aux_dtype=aux_dtype,
+            metadata_buffers=self.disagg_metadata_buffers,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -485,7 +485,6 @@ def popen_launch_pd_server(
     api_key: Optional[str] = None,
     other_args: list[str] = (),
     env: Optional[dict] = None,
-    return_stdout_stderr: Optional[tuple] = None,
 ):
     _, host, port = base_url.split(":")
     host = host[2:]
@@ -515,42 +514,9 @@ def popen_launch_pd_server(

     print(f"command={' '.join(command)}")

-    if return_stdout_stderr:
-        process = subprocess.Popen(
-            command,
-            stdout=return_stdout_stderr[0],
-            stderr=return_stdout_stderr[1],
-            env=env,
-            text=True,
-        )
-    else:
-        process = subprocess.Popen(command, stdout=None, stderr=None, env=env)
+    process = subprocess.Popen(command, stdout=None, stderr=None, env=env)

-    start_time = time.perf_counter()
-    with requests.Session() as session:
-        while time.perf_counter() - start_time < timeout:
-            try:
-                headers = {
-                    "Content-Type": "application/json; charset=utf-8",
-                    "Authorization": f"Bearer {api_key}",
-                }
-                response = session.get(
-                    f"{base_url}/health",
-                    headers=headers,
-                )
-                if response.status_code == 200:
-                    return process
-            except requests.RequestException:
-                pass
-
-            return_code = process.poll()
-            if return_code is not None:
-                raise Exception(f"Server unexpectedly exits ({return_code=}).")
-
-            time.sleep(10)
-
-    kill_process_tree(process.pid)
-    raise TimeoutError("Server failed to start within the timeout period.")
+    return process


 def run_with_timeout(
scripts/playground/disaggregation/cli-logprob.py (new file, 22 lines)
@@ -0,0 +1,22 @@
prompt = "The capital of taiwan is "
|
||||
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
response = requests.post(
|
||||
"http://0.0.0.0:8000/generate",
|
||||
json={
|
||||
"text": prompt,
|
||||
"sampling_params": {"temperature": 0},
|
||||
"return_logprob": True,
|
||||
"return_input_logprob": True,
|
||||
"logprob_start_len": 0,
|
||||
},
|
||||
)
|
||||
|
||||
j = response.json()
|
||||
input_logprobs = j["meta_info"]["input_token_logprobs"]
|
||||
output_logprobs = j["meta_info"]["output_token_logprobs"]
|
||||
|
||||
print(len(input_logprobs), len(output_logprobs))
|
||||
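As a quick manual check: with a prefill/decode pair running behind the mini load balancer on port 8000, this script should print two nonzero counts, the merged prompt logprobs (prepended from the prefill engine) and the decode engine's output logprobs.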
@@ -1,7 +1,9 @@
+import os
 import subprocess
 import time
 import unittest
 from types import SimpleNamespace
+from urllib.parse import urlparse

 import requests
@@ -25,15 +27,22 @@ class TestDisaggregationAccuracy(CustomTestCase):
     @classmethod
     def setUpClass(cls):
         cls.model = DEFAULT_MODEL_NAME_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
@@ -48,7 +57,7 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]

         print("Starting load balancer:", " ".join(lb_command))
@@ -63,14 +72,10 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
         ]
         cls.process_prefill = popen_launch_pd_server(
             cls.model,
@@ -85,16 +90,12 @@ class TestDisaggregationAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "1",
             "--base-gpu-id",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
         ]
         cls.process_decode = popen_launch_pd_server(
             cls.model,
@@ -128,6 +129,9 @@ class TestDisaggregationAccuracy(CustomTestCase):
             except Exception as e:
                 print(f"Error killing process {process.pid}: {e}")

+        # wait for 5 seconds
+        time.sleep(5)
+
     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
@@ -135,45 +139,63 @@ class TestDisaggregationAccuracy(CustomTestCase):
             num_questions=200,
             max_new_tokens=512,
             parallel=128,
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")

         self.assertGreater(metrics["accuracy"], 0.62)

+    def test_logprob(self):
+        prompt = "The capital of taiwan is "
+        response = requests.post(
+            self.lb_url + "/generate",
+            json={
+                "text": prompt,
+                "sampling_params": {"temperature": 0},
+                "return_logprob": True,
+                "return_input_logprob": True,
+                "logprob_start_len": 0,
+            },
+        )
+
+        j = response.json()
+        completion_tokens = j["meta_info"]["completion_tokens"]
+        input_logprobs = j["meta_info"]["input_token_logprobs"]
+        output_logprobs = j["meta_info"]["output_token_logprobs"]
+
+        assert (
+            len(output_logprobs) == completion_tokens
+        ), f"output_logprobs and completion_tokens should have the same length, but got {len(output_logprobs)} and {completion_tokens}"
+        assert (
+            len(input_logprobs) > 0
+        ), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"


-class TestDisaggregationSpecAccuracy(CustomTestCase):
+class TestDisaggregationMooncakeFailure(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
-        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
-        cls.base_host = "127.0.0.1"
-        cls.base_port = int(DEFAULT_URL_FOR_TEST.split(":")[-1])
-        cls.lb_url = DEFAULT_URL_FOR_TEST
-        cls.prefill_url = f"http://{cls.base_host}:{cls.base_port + 100}"
-        cls.decode_url = f"http://{cls.base_host}:{cls.base_port + 200}"
-        cls.spec_args = [
-            "--speculative-algorithm",
-            "EAGLE",
-            "--speculative-draft-model-path",
-            cls.draft_model,
-            "--speculative-num-steps",
-            "3",
-            "--speculative-eagle-topk",
-            "4",
-            "--speculative-num-draft-tokens",
-            "16",
-            "--cuda-graph-max-bs",
-            "8",
-        ]
-
-        run_with_timeout(cls.start_prefill, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
-        run_with_timeout(cls.start_decode, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH)
+        super().setUpClass()
+        # set DISAGGREGATION_TEST_FAILURE_PROB to simulate failure
+        os.environ["DISAGGREGATION_TEST_FAILURE_PROB"] = "0.05"
+
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
@@ -188,7 +210,149 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--host",
             cls.base_host,
             "--port",
-            str(cls.base_port),
+            cls.lb_port,
         ]

         print("Starting load balancer:", " ".join(lb_command))
         cls.process_lb = subprocess.Popen(
             lb_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
         )
         cls.wait_server_ready(cls.lb_url + "/health")

+    @classmethod
+    def start_prefill(cls):
+        prefill_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "prefill",
+            "--tp",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce0",
+        ]
+        cls.process_prefill = popen_launch_pd_server(
+            cls.model,
+            cls.prefill_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=prefill_args,
+        )
+
+    @classmethod
+    def start_decode(cls):
+        decode_args = [
+            "--trust-remote-code",
+            "--disaggregation-mode",
+            "decode",
+            "--tp",
+            "1",
+            "--base-gpu-id",
+            "1",
+            "--disaggregation-ib-device",
+            "mlx5_roce1",
+        ]
+        cls.process_decode = popen_launch_pd_server(
+            cls.model,
+            cls.decode_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=decode_args,
+        )
+
+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
+    @classmethod
+    def tearDownClass(cls):
+        # unset DISAGGREGATION_TEST_FAILURE_PROB
+        os.environ.pop("DISAGGREGATION_TEST_FAILURE_PROB")
+        for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
+            if process:
+                try:
+                    kill_process_tree(process.pid)
+                except Exception as e:
+                    print(f"Error killing process {process.pid}: {e}")
+
+        # wait for 5 seconds
+        time.sleep(5)
+
+    def test_gsm8k(self):
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=None,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"Evaluation metrics: {metrics}")
+        # Expect lots of failure but the server cannot crash
+
+
+class TestDisaggregationMooncakeSpec(CustomTestCase):
+
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        cls.draft_model = DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST
+        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
+        cls.base_host = parsed_url.hostname
+        base_port = str(parsed_url.port)
+        cls.lb_port = base_port
+        cls.prefill_port = f"{int(base_port) + 100}"
+        cls.decode_port = f"{int(base_port) + 200}"
+        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
+        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
+        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
+        cls.spec_args = [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft-model-path",
+            cls.draft_model,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "4",
+            "--speculative-num-draft-tokens",
+            "16",
+            "--cuda-graph-max-bs",
+            "8",
+        ]
+        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")
+
+        # Non blocking start servers
+        cls.start_prefill()
+        cls.start_decode()
+
+        # Block until both
+        cls.wait_server_ready(cls.prefill_url + "/health")
+        cls.wait_server_ready(cls.decode_url + "/health")
+
+        lb_command = [
+            "python3",
+            "-m",
+            "sglang.srt.disaggregation.mini_lb",
+            "--prefill",
+            cls.prefill_url,
+            "--decode",
+            cls.decode_url,
+            "--host",
+            cls.base_host,
+            "--port",
+            cls.lb_port,
+        ]

         print("Starting load balancer:", " ".join(lb_command))
@@ -215,21 +379,15 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):

     @classmethod
     def start_prefill(cls):
-
         prefill_args = [
             "--trust-remote-code",
             "--disaggregation-mode",
             "prefill",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 100),
             "--tp",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce0,mlx5_roce1,mlx5_roce2,mlx5_roce3",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce0,mlx5_roce1",
         ] + cls.spec_args

         cls.process_prefill = popen_launch_pd_server(
             cls.model,
             cls.prefill_url,
@@ -243,16 +401,12 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             "--trust-remote-code",
             "--disaggregation-mode",
             "decode",
-            "--host",
-            cls.base_host,
-            "--port",
-            str(cls.base_port + 200),
             "--tp",
-            "4",
+            "2",
             "--base-gpu-id",
-            "4",
-            # "--disaggregation-ib-device",
-            # "mlx5_roce4,mlx5_roce5,mlx5_roce6,mlx5_roce7",
+            "2",
+            "--disaggregation-ib-device",
+            "mlx5_roce2,mlx5_roce3",
         ] + cls.spec_args
         cls.process_decode = popen_launch_pd_server(
             cls.model,
@@ -261,15 +415,43 @@ class TestDisaggregationSpecAccuracy(CustomTestCase):
             other_args=decode_args,
         )

+    @classmethod
+    def wait_server_ready(cls, url, timeout=60):
+        start_time = time.perf_counter()
+        while True:
+            try:
+                response = requests.get(url)
+                if response.status_code == 200:
+                    print(f"Server {url} is ready")
+                    return
+            except Exception:
+                pass
+
+            if time.perf_counter() - start_time > timeout:
+                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
+            time.sleep(1)
+
     @classmethod
     def tearDownClass(cls):
         for process in [cls.process_lb, cls.process_decode, cls.process_prefill]:
             if process:
                 try:
                     kill_process_tree(process.pid)
                 except Exception as e:
                     print(f"Error killing process {process.pid}: {e}")

+        # wait for 5 seconds
+        time.sleep(5)
+
     def test_gsm8k(self):
         args = SimpleNamespace(
             num_shots=5,
             data_path=None,
             num_questions=200,
             max_new_tokens=512,
-            parallel=4,  # TODO: 128 crashes the decode
-            host="http://127.0.0.1",
-            port=int(self.lb_url.split(":")[-1]),
+            parallel=2,
+            host=f"http://{self.base_host}",
+            port=int(self.lb_port),
         )
         metrics = run_eval_few_shot_gsm8k(args)
         print(f"Evaluation metrics: {metrics}")