[PD] Support PD disaggregation with Prefill PP (#8846)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com> Signed-off-by: Shangming Cai <csmthu@gmail.com> Co-authored-by: root <huzhiyuan@xiaohongshu.com> Co-authored-by: Ying Sheng <sqy1415@gmail.com> Co-authored-by: Francis <38564764+ssssnow@users.noreply.github.com> Co-authored-by: zitto <zhjc1124@gmail.com>
This commit is contained in:
@@ -30,6 +30,7 @@ class KVArgs:
|
|||||||
# for pp prefill
|
# for pp prefill
|
||||||
prefill_pp_size: int
|
prefill_pp_size: int
|
||||||
pp_rank: int
|
pp_rank: int
|
||||||
|
prefill_start_layer: int
|
||||||
# for system dp
|
# for system dp
|
||||||
system_dp_rank: int
|
system_dp_rank: int
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from sglang.srt.disaggregation.common.utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.distributed import get_pp_group
|
||||||
from sglang.srt.layers.dp_attention import (
|
from sglang.srt.layers.dp_attention import (
|
||||||
get_attention_dp_rank,
|
get_attention_dp_rank,
|
||||||
get_attention_dp_size,
|
get_attention_dp_size,
|
||||||
@@ -180,6 +181,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
self.session_failures = defaultdict(int)
|
self.session_failures = defaultdict(int)
|
||||||
self.failed_sessions = set()
|
self.failed_sessions = set()
|
||||||
self.session_lock = threading.Lock()
|
self.session_lock = threading.Lock()
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
# Determine the number of threads to use for kv sender
|
# Determine the number of threads to use for kv sender
|
||||||
cpu_count = os.cpu_count()
|
cpu_count = os.cpu_count()
|
||||||
transfer_thread_pool_size = get_int_env_var(
|
transfer_thread_pool_size = get_int_env_var(
|
||||||
@@ -313,11 +315,11 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
layers_params = None
|
layers_params = None
|
||||||
|
|
||||||
# pp is not supported on the decode side yet
|
# pp is not supported on the decode side yet
|
||||||
|
start_layer = self.kv_args.prefill_start_layer
|
||||||
|
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
||||||
if self.is_mla_backend:
|
if self.is_mla_backend:
|
||||||
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
||||||
layers_per_pp_stage = len(src_kv_ptrs)
|
layers_per_pp_stage = len(src_kv_ptrs)
|
||||||
start_layer = self.pp_rank * layers_per_pp_stage
|
|
||||||
end_layer = start_layer + layers_per_pp_stage
|
|
||||||
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
layers_params = [
|
layers_params = [
|
||||||
@@ -330,17 +332,15 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||||
|
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||||
layers_per_pp_stage = len(src_k_ptrs)
|
layers_per_pp_stage = len(src_k_ptrs)
|
||||||
start_layer = self.pp_rank * layers_per_pp_stage
|
|
||||||
end_layer = start_layer + layers_per_pp_stage
|
|
||||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
dst_v_ptrs = dst_kv_ptrs[
|
dst_v_ptrs = dst_kv_ptrs[
|
||||||
num_kv_layers + start_layer : num_kv_layers + end_layer
|
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||||
]
|
]
|
||||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
|
|
||||||
layers_params = [
|
layers_params = [
|
||||||
(
|
(
|
||||||
src_k_ptrs[layer_id],
|
src_k_ptrs[layer_id],
|
||||||
@@ -452,6 +452,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
|
|
||||||
# pp is not supported on the decode side yet
|
# pp is not supported on the decode side yet
|
||||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||||
|
dst_num_total_layers = num_kv_layers * self.pp_size
|
||||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||||
layers_per_pp_stage = len(src_k_ptrs)
|
layers_per_pp_stage = len(src_k_ptrs)
|
||||||
@@ -459,7 +460,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
end_layer = start_layer + layers_per_pp_stage
|
end_layer = start_layer + layers_per_pp_stage
|
||||||
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
||||||
dst_v_ptrs = dst_kv_ptrs[
|
dst_v_ptrs = dst_kv_ptrs[
|
||||||
num_kv_layers + start_layer : num_kv_layers + end_layer
|
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
||||||
]
|
]
|
||||||
|
|
||||||
# Calculate precise byte offset and length for the sub-slice within the token
|
# Calculate precise byte offset and length for the sub-slice within the token
|
||||||
@@ -612,7 +613,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
)
|
)
|
||||||
polls = []
|
polls = []
|
||||||
dst_ranks_infos = []
|
dst_ranks_infos = []
|
||||||
local_rank = self.kv_args.engine_rank
|
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
|
||||||
for req in reqs_to_be_processed:
|
for req in reqs_to_be_processed:
|
||||||
if not req.is_dummy:
|
if not req.is_dummy:
|
||||||
# Early exit if the request has failed
|
# Early exit if the request has failed
|
||||||
@@ -695,13 +696,14 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if kv_chunk.is_last:
|
if kv_chunk.is_last:
|
||||||
# Only the last chunk we need to send the aux data
|
if self.pp_group.is_last_rank:
|
||||||
ret = self.send_aux(
|
# Only the last chunk we need to send the aux data
|
||||||
req.mooncake_session_id,
|
ret = self.send_aux(
|
||||||
kv_chunk.prefill_aux_index,
|
req.mooncake_session_id,
|
||||||
target_rank_registration_info.dst_aux_ptrs,
|
kv_chunk.prefill_aux_index,
|
||||||
req.dst_aux_index,
|
target_rank_registration_info.dst_aux_ptrs,
|
||||||
)
|
req.dst_aux_index,
|
||||||
|
)
|
||||||
polls.append(True if ret == 0 else False)
|
polls.append(True if ret == 0 else False)
|
||||||
dst_ranks_infos.append(
|
dst_ranks_infos.append(
|
||||||
(req.endpoint, req.dst_port, req.room)
|
(req.endpoint, req.dst_port, req.room)
|
||||||
@@ -798,10 +800,7 @@ class MooncakeKVManager(BaseKVManager):
|
|||||||
arrived_response_num = len(
|
arrived_response_num = len(
|
||||||
self.prefill_response_tracker[bootstrap_room]
|
self.prefill_response_tracker[bootstrap_room]
|
||||||
)
|
)
|
||||||
if (
|
if arrived_response_num == expected_response_num:
|
||||||
self.is_mla_backend
|
|
||||||
or arrived_response_num == expected_response_num
|
|
||||||
):
|
|
||||||
self.update_status(bootstrap_room, KVPoll.Success)
|
self.update_status(bootstrap_room, KVPoll.Success)
|
||||||
elif status == KVPoll.Failed:
|
elif status == KVPoll.Failed:
|
||||||
self.record_failure(
|
self.record_failure(
|
||||||
@@ -1183,7 +1182,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
||||||
)
|
)
|
||||||
self.required_dst_info_num = 1
|
self.required_dst_info_num = 1
|
||||||
self.required_prefill_response_num = 1
|
self.required_prefill_response_num = 1 * (
|
||||||
|
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||||
|
)
|
||||||
self.target_tp_ranks = [self.target_tp_rank]
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
||||||
if not self.kv_mgr.is_mla_backend:
|
if not self.kv_mgr.is_mla_backend:
|
||||||
@@ -1196,7 +1197,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
self.required_dst_info_num = (
|
self.required_dst_info_num = (
|
||||||
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
||||||
)
|
)
|
||||||
self.required_prefill_response_num = 1
|
self.required_prefill_response_num = 1 * (
|
||||||
|
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||||
|
)
|
||||||
self.target_tp_ranks = [self.target_tp_rank]
|
self.target_tp_ranks = [self.target_tp_rank]
|
||||||
else:
|
else:
|
||||||
if not self.kv_mgr.is_mla_backend:
|
if not self.kv_mgr.is_mla_backend:
|
||||||
@@ -1219,9 +1222,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|||||||
# or the KVPoll will never be set correctly
|
# or the KVPoll will never be set correctly
|
||||||
self.target_tp_rank = self.target_tp_ranks[0]
|
self.target_tp_rank = self.target_tp_ranks[0]
|
||||||
self.required_dst_info_num = 1
|
self.required_dst_info_num = 1
|
||||||
self.required_prefill_response_num = (
|
if self.kv_mgr.is_mla_backend:
|
||||||
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
self.required_prefill_response_num = (
|
||||||
)
|
self.prefill_pp_size // self.kv_mgr.pp_size
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.required_prefill_response_num = (
|
||||||
|
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
||||||
|
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
||||||
|
|
||||||
if self.data_parallel_rank is not None:
|
if self.data_parallel_rank is not None:
|
||||||
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
||||||
@@ -1530,7 +1538,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|||||||
"rank_port": rank_port,
|
"rank_port": rank_port,
|
||||||
}
|
}
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Register prefill bootstrap: DP {dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||||
)
|
)
|
||||||
|
|
||||||
return web.Response(text="OK", status=200)
|
return web.Response(text="OK", status=200)
|
||||||
|
|||||||
@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
|
|||||||
prepare_abort,
|
prepare_abort,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
||||||
from sglang.srt.utils import require_mlp_sync
|
from sglang.srt.utils import (
|
||||||
|
DynamicGradMode,
|
||||||
|
broadcast_pyobj,
|
||||||
|
point_to_point_pyobj,
|
||||||
|
require_mlp_sync,
|
||||||
|
)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
|
|||||||
kv_args.system_dp_rank = self.scheduler.dp_rank
|
kv_args.system_dp_rank = self.scheduler.dp_rank
|
||||||
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
||||||
kv_args.prefill_pp_size = self.pp_size
|
kv_args.prefill_pp_size = self.pp_size
|
||||||
|
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
|
||||||
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
||||||
self.token_to_kv_pool.get_contiguous_buf_infos()
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
||||||
)
|
)
|
||||||
@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
|
|||||||
polls = poll_and_all_reduce(
|
polls = poll_and_all_reduce(
|
||||||
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
||||||
)
|
)
|
||||||
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
|
||||||
|
|
||||||
|
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
||||||
if rids_to_check is not None:
|
if rids_to_check is not None:
|
||||||
# if req not in reqs_info_to_check, skip
|
# if req not in reqs_info_to_check, skip
|
||||||
if req.rid not in rids_to_check:
|
if req.rid not in rids_to_check:
|
||||||
@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
req.output_ids.append(next_token_id)
|
req.output_ids.append(next_token_id)
|
||||||
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
||||||
self.disagg_prefill_inflight_queue.append(req)
|
self.disagg_prefill_inflight_queue.append(req)
|
||||||
if logits_output.hidden_states is not None:
|
if (
|
||||||
|
logits_output is not None
|
||||||
|
and logits_output.hidden_states is not None
|
||||||
|
):
|
||||||
last_hidden_index = (
|
last_hidden_index = (
|
||||||
hidden_state_offset + extend_input_len_per_req[i] - 1
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
||||||
)
|
)
|
||||||
@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
req.disagg_kv_sender.send(page_indices)
|
req.disagg_kv_sender.send(page_indices)
|
||||||
|
|
||||||
|
# PP
|
||||||
|
@DynamicGradMode()
|
||||||
|
def event_loop_pp_disagg_prefill(self: Scheduler):
|
||||||
|
"""
|
||||||
|
An event loop for the prefill server in pipeline parallelism.
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
1. Each stage runs in the same order and is notified by the previous stage.
|
||||||
|
2. Each send/recv operation is blocking and matched by the neighboring stage.
|
||||||
|
|
||||||
|
Regular Schedule:
|
||||||
|
====================================================================
|
||||||
|
Stage i | Stage i+1
|
||||||
|
send ith req | recv ith req
|
||||||
|
send ith proxy | recv ith proxy
|
||||||
|
send prev (i+1)th carry | recv prev (i+1)th carry
|
||||||
|
====================================================================
|
||||||
|
|
||||||
|
Prefill Server Schedule:
|
||||||
|
====================================================================
|
||||||
|
Stage i | Stage i+1
|
||||||
|
send ith req | recv ith req
|
||||||
|
send ith bootstrap req | recv ith bootstrap req
|
||||||
|
send ith transferred req | recv ith transferred req
|
||||||
|
send ith proxy | recv ith proxy
|
||||||
|
send prev (i+1)th carry | recv prev (i+1)th carry
|
||||||
|
send prev (i+1)th release req | recv prev (i+1)th release req
|
||||||
|
====================================================================
|
||||||
|
|
||||||
|
There are two additional elements compared to the regular schedule:
|
||||||
|
|
||||||
|
1. Bootstrap Requests:
|
||||||
|
a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
|
||||||
|
b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
|
||||||
|
c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
|
||||||
|
|
||||||
|
2. Transferred Requests + Release Requests:
|
||||||
|
a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
|
||||||
|
b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
|
||||||
|
c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
|
||||||
|
"""
|
||||||
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||||
|
|
||||||
|
mbs = [None] * self.pp_size
|
||||||
|
last_mbs = [None] * self.pp_size
|
||||||
|
self.running_mbs = [
|
||||||
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
||||||
|
]
|
||||||
|
bids = [None] * self.pp_size
|
||||||
|
pp_outputs: Optional[PPProxyTensors] = None
|
||||||
|
|
||||||
|
# Either success or failed
|
||||||
|
bootstrapped_rids: List[str] = []
|
||||||
|
transferred_rids: List[str] = []
|
||||||
|
release_rids: Optional[List[str]] = None
|
||||||
|
|
||||||
|
# transferred microbatch
|
||||||
|
tmbs = [None] * self.pp_size
|
||||||
|
|
||||||
|
ENABLE_RELEASE = True # For debug
|
||||||
|
|
||||||
|
while True:
|
||||||
|
server_is_idle = True
|
||||||
|
|
||||||
|
for mb_id in range(self.pp_size):
|
||||||
|
self.running_batch = self.running_mbs[mb_id]
|
||||||
|
self.last_batch = last_mbs[mb_id]
|
||||||
|
|
||||||
|
recv_reqs = self.recv_requests()
|
||||||
|
|
||||||
|
self.process_input_requests(recv_reqs)
|
||||||
|
|
||||||
|
if self.pp_group.is_first_rank:
|
||||||
|
# First rank, pop the bootstrap reqs from the bootstrap queue
|
||||||
|
bootstrapped_reqs, failed_reqs = (
|
||||||
|
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
|
||||||
|
return_failed_reqs=True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
|
||||||
|
req.rid for req in failed_reqs
|
||||||
|
]
|
||||||
|
self.waiting_queue.extend(bootstrapped_reqs)
|
||||||
|
else:
|
||||||
|
# Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
|
||||||
|
bootstrapped_rids = self.recv_pyobj_from_prev_stage()
|
||||||
|
bootstrapped_reqs = (
|
||||||
|
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
|
||||||
|
rids_to_check=bootstrapped_rids
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.waiting_queue.extend(bootstrapped_reqs)
|
||||||
|
|
||||||
|
if self.pp_group.is_first_rank:
|
||||||
|
transferred_rids = self.get_transferred_rids()
|
||||||
|
# Other ranks:
|
||||||
|
else:
|
||||||
|
# 1. recv previous stage's transferred reqs info
|
||||||
|
prev_transferred_rids = self.recv_pyobj_from_prev_stage()
|
||||||
|
# 2. get the current stage's transferred reqs info
|
||||||
|
curr_transferred_rids = self.get_transferred_rids()
|
||||||
|
# 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
|
||||||
|
transferred_rids = list(
|
||||||
|
set(prev_transferred_rids) & set(curr_transferred_rids)
|
||||||
|
)
|
||||||
|
|
||||||
|
tmbs[mb_id] = transferred_rids
|
||||||
|
|
||||||
|
self.process_prefill_chunk()
|
||||||
|
mbs[mb_id] = self.get_new_batch_prefill()
|
||||||
|
self.running_mbs[mb_id] = self.running_batch
|
||||||
|
|
||||||
|
self.cur_batch = mbs[mb_id]
|
||||||
|
if self.cur_batch:
|
||||||
|
server_is_idle = False
|
||||||
|
result = self.run_batch(self.cur_batch)
|
||||||
|
|
||||||
|
# send the outputs to the next step
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
if self.cur_batch:
|
||||||
|
next_token_ids, bids[mb_id] = (
|
||||||
|
result.next_token_ids,
|
||||||
|
result.bid,
|
||||||
|
)
|
||||||
|
pp_outputs = PPProxyTensors(
|
||||||
|
{
|
||||||
|
"next_token_ids": next_token_ids,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# send the output from the last round to let the next stage worker run post processing
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
pp_outputs.tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
if ENABLE_RELEASE:
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
# At the last stage, all stages have reached consensus to release memory for transferred_rids
|
||||||
|
release_rids = transferred_rids
|
||||||
|
# send to the first rank
|
||||||
|
self.send_pyobj_to_next_stage(release_rids)
|
||||||
|
|
||||||
|
# receive outputs and post-process (filter finished reqs) the coming microbatch
|
||||||
|
next_mb_id = (mb_id + 1) % self.pp_size
|
||||||
|
next_pp_outputs = None
|
||||||
|
next_release_rids = None
|
||||||
|
|
||||||
|
if mbs[next_mb_id] is not None:
|
||||||
|
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
|
||||||
|
self.pp_group.recv_tensor_dict(
|
||||||
|
all_gather_group=self.attn_tp_group
|
||||||
|
)
|
||||||
|
)
|
||||||
|
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
||||||
|
output_result = GenerationBatchResult(
|
||||||
|
logits_output=None,
|
||||||
|
pp_hidden_states_proxy_tensors=None,
|
||||||
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
||||||
|
extend_input_len_per_req=None,
|
||||||
|
extend_logprob_start_len_per_req=None,
|
||||||
|
bid=bids[next_mb_id],
|
||||||
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
||||||
|
)
|
||||||
|
self.process_batch_result_disagg_prefill(
|
||||||
|
mbs[next_mb_id], output_result
|
||||||
|
)
|
||||||
|
|
||||||
|
last_mbs[next_mb_id] = mbs[next_mb_id]
|
||||||
|
|
||||||
|
if ENABLE_RELEASE:
|
||||||
|
if tmbs[next_mb_id] is not None:
|
||||||
|
# recv consensus rids from the previous rank
|
||||||
|
next_release_rids = self.recv_pyobj_from_prev_stage()
|
||||||
|
self.process_disagg_prefill_inflight_queue(next_release_rids)
|
||||||
|
|
||||||
|
# carry the outputs to the next stage
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
if self.cur_batch:
|
||||||
|
bids[mb_id] = result.bid
|
||||||
|
if pp_outputs:
|
||||||
|
# send the outputs from the last round to let the next stage worker run post processing
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
pp_outputs.tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
if ENABLE_RELEASE:
|
||||||
|
if release_rids is not None:
|
||||||
|
self.send_pyobj_to_next_stage(release_rids)
|
||||||
|
|
||||||
|
if not self.pp_group.is_last_rank:
|
||||||
|
# send out reqs to the next stage
|
||||||
|
self.send_pyobj_to_next_stage(recv_reqs)
|
||||||
|
self.send_pyobj_to_next_stage(bootstrapped_rids)
|
||||||
|
self.send_pyobj_to_next_stage(transferred_rids)
|
||||||
|
|
||||||
|
# send out proxy tensors to the next stage
|
||||||
|
if self.cur_batch:
|
||||||
|
self.pp_group.send_tensor_dict(
|
||||||
|
result.pp_hidden_states_proxy_tensors,
|
||||||
|
all_gather_group=self.attn_tp_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
pp_outputs = next_pp_outputs
|
||||||
|
release_rids = next_release_rids
|
||||||
|
|
||||||
|
self.running_batch.batch_is_full = False
|
||||||
|
|
||||||
|
if not ENABLE_RELEASE:
|
||||||
|
if len(self.disagg_prefill_inflight_queue) > 0:
|
||||||
|
self.process_disagg_prefill_inflight_queue()
|
||||||
|
|
||||||
|
# When the server is idle, self-check and re-init some states
|
||||||
|
if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
|
||||||
|
self.check_memory()
|
||||||
|
self.check_tree_cache()
|
||||||
|
self.new_token_ratio = self.init_new_token_ratio
|
||||||
|
|
||||||
|
def send_pyobj_to_next_stage(self, data):
    """Hand a Python object to the matching rank of the following PP stage.

    Only ranks whose ``attn_tp_rank`` is 0 take part; every other rank
    returns immediately without sending anything.

    Args:
        data: The object forwarded to ``point_to_point_pyobj``.
    """
    if self.attn_tp_rank != 0:
        return

    dp_offset = self.attn_dp_rank * self.attn_tp_size
    # This rank's own slot; it is passed twice below (second and fourth
    # positional arguments), exactly as the call was originally written.
    # NOTE(review): presumably the helper takes (data, rank, group, src, dst)
    # -- confirm against point_to_point_pyobj.
    own_rank = self.pp_rank * self.tp_size + dp_offset
    # The modulo wraps the last stage back around to stage 0.
    following_stage = (self.pp_rank + 1) % self.pp_size
    peer_rank = following_stage * self.tp_size + dp_offset

    point_to_point_pyobj(
        data,
        own_rank,
        self.world_group.device_group,
        own_rank,
        peer_rank,
    )
|
||||||
|
|
||||||
|
def recv_pyobj_from_prev_stage(self):
    """Receive a Python object from the matching rank of the preceding PP stage.

    Ranks whose ``attn_tp_rank`` is 0 obtain the object through
    ``point_to_point_pyobj``; all other ranks start from ``None``. When
    ``tp_size`` is not 1 the value is then passed through ``broadcast_pyobj``
    and the broadcast result is what gets returned.

    Returns:
        The received (and, if applicable, broadcast) object, or ``None`` on a
        rank that neither receives nor broadcasts.
    """
    received = None

    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        own_rank = self.pp_rank * self.tp_size + dp_offset
        # The modulo wraps stage 0 back around to the last stage.
        preceding_stage = (self.pp_rank - 1) % self.pp_size
        received = point_to_point_pyobj(
            [],  # empty placeholder payload on the receiving side
            own_rank,
            self.world_group.device_group,
            preceding_stage * self.tp_size + dp_offset,
            own_rank,
        )

    if self.tp_size == 1:
        return received

    return broadcast_pyobj(
        received, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
    )
|
||||||
|
|||||||
@@ -2579,7 +2579,10 @@ def run_scheduler_process(
|
|||||||
if scheduler.enable_overlap:
|
if scheduler.enable_overlap:
|
||||||
scheduler.event_loop_overlap_disagg_prefill()
|
scheduler.event_loop_overlap_disagg_prefill()
|
||||||
else:
|
else:
|
||||||
scheduler.event_loop_normal_disagg_prefill()
|
if server_args.pp_size > 1:
|
||||||
|
scheduler.event_loop_pp_disagg_prefill()
|
||||||
|
else:
|
||||||
|
scheduler.event_loop_normal_disagg_prefill()
|
||||||
|
|
||||||
elif disaggregation_mode == DisaggregationMode.DECODE:
|
elif disaggregation_mode == DisaggregationMode.DECODE:
|
||||||
if scheduler.enable_overlap:
|
if scheduler.enable_overlap:
|
||||||
|
|||||||
@@ -1,9 +1,16 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as mp
|
import multiprocessing as mp
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import Dict, List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
||||||
|
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -41,6 +48,57 @@ def validate_input_length(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def get_logprob_dict_from_result(result: GenerationBatchResult) -> dict:
    """Flatten the logprob-related fields of a batch result into a plain dict.

    Args:
        result: Batch result whose ``logits_output`` must not be ``None``.

    Returns:
        A dict holding the two per-request extend-length fields taken from
        ``result`` itself, followed by ten logprob fields taken from
        ``result.logits_output``, each stored under its attribute name.

    Raises:
        AssertionError: If ``result.logits_output`` is ``None``.
    """
    logits_output = result.logits_output
    assert logits_output is not None

    # Attribute names copied from logits_output, in output order.
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )

    payload = {
        "extend_input_len_per_req": result.extend_input_len_per_req,
        "extend_logprob_start_len_per_req": result.extend_logprob_start_len_per_req,
    }
    for field_name in logprob_fields:
        payload[field_name] = getattr(logits_output, field_name)
    return payload
|
||||||
|
|
||||||
|
|
||||||
|
def get_logprob_from_pp_outputs(
    next_pp_outputs: PPProxyTensors,
) -> tuple[LogitsProcessorOutput, list[int], list[int]]:
    """Rebuild logprob outputs from the entries of ``next_pp_outputs``.

    Args:
        next_pp_outputs: Container indexed by string key; it must provide the
            ten logprob entries plus ``extend_input_len_per_req`` and
            ``extend_logprob_start_len_per_req``.

    Returns:
        A 3-tuple of the reconstructed ``LogitsProcessorOutput``, the
        ``extend_input_len_per_req`` entry and the
        ``extend_logprob_start_len_per_req`` entry.
    """
    # Keys looked up in next_pp_outputs and passed through under the same
    # keyword name, in this order.
    logprob_fields = (
        "next_token_logprobs",
        "next_token_top_logprobs_val",
        "next_token_top_logprobs_idx",
        "next_token_token_ids_logprobs_val",
        "next_token_token_ids_logprobs_idx",
        "input_token_logprobs",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
    )
    forwarded = {name: next_pp_outputs[name] for name in logprob_fields}

    # Logits and hidden states are deliberately left as None: they are large,
    # so they are not sent along.
    rebuilt = LogitsProcessorOutput(
        next_token_logits=None,
        hidden_states=None,
        **forwarded,
    )

    input_lens = next_pp_outputs["extend_input_len_per_req"]
    logprob_start_lens = next_pp_outputs["extend_logprob_start_len_per_req"]
    return rebuilt, input_lens, logprob_start_lens
|
||||||
|
|
||||||
|
|
||||||
class DPBalanceMeta:
|
class DPBalanceMeta:
|
||||||
"""
|
"""
|
||||||
This class will be used in scheduler and dp controller
|
This class will be used in scheduler and dp controller
|
||||||
|
|||||||
@@ -849,7 +849,7 @@ class MLATokenToKVPool(KVCache):
|
|||||||
cache_k_rope = cache_k_rope.view(self.store_dtype)
|
cache_k_rope = cache_k_rope.view(self.store_dtype)
|
||||||
|
|
||||||
set_mla_kv_buffer_triton(
|
set_mla_kv_buffer_triton(
|
||||||
self.kv_buffer[layer_id], loc, cache_k_nope, cache_k_rope
|
self.kv_buffer[layer_id - self.start_layer], loc, cache_k_nope, cache_k_rope
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_cpu_copy(self, indices):
|
def get_cpu_copy(self, indices):
|
||||||
|
|||||||
@@ -307,8 +307,13 @@ class ModelRunner:
|
|||||||
self.start_layer = getattr(self.model, "start_layer", 0)
|
self.start_layer = getattr(self.model, "start_layer", 0)
|
||||||
self.end_layer = getattr(self.model, "end_layer", model_num_layers)
|
self.end_layer = getattr(self.model, "end_layer", model_num_layers)
|
||||||
self.num_effective_layers = self.end_layer - self.start_layer
|
self.num_effective_layers = self.end_layer - self.start_layer
|
||||||
assert (not model_has_mtp_layers) or (
|
assert (
|
||||||
self.num_effective_layers == model_num_layers
|
(not model_has_mtp_layers)
|
||||||
|
or (self.spec_algorithm.is_none())
|
||||||
|
or (
|
||||||
|
(not self.spec_algorithm.is_none())
|
||||||
|
and (self.num_effective_layers == model_num_layers)
|
||||||
|
)
|
||||||
), "PP is not compatible with MTP models."
|
), "PP is not compatible with MTP models."
|
||||||
|
|
||||||
# Apply torchao quantization
|
# Apply torchao quantization
|
||||||
@@ -1048,8 +1053,6 @@ class ModelRunner:
|
|||||||
else:
|
else:
|
||||||
num_layers = self.num_effective_layers
|
num_layers = self.num_effective_layers
|
||||||
if self.use_mla_backend:
|
if self.use_mla_backend:
|
||||||
# FIXME: pipeline parallelism is not compatible with mla backend
|
|
||||||
assert self.pp_size == 1
|
|
||||||
cell_size = (
|
cell_size = (
|
||||||
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
|
||||||
* num_layers
|
* num_layers
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
|
||||||
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
|
||||||
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
@@ -135,6 +135,8 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
|
# if not set, model load will be broken in DeepseekV3ForCausalLM load_weights()
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")
|
self.determine_num_fused_shared_experts("DeepseekV3ForCausalLMNextN")
|
||||||
|
|
||||||
self.model = DeepseekModelNextN(
|
self.model = DeepseekModelNextN(
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import concurrent.futures
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -30,6 +30,7 @@ from transformers import PretrainedConfig
|
|||||||
|
|
||||||
from sglang.srt.distributed import (
|
from sglang.srt.distributed import (
|
||||||
get_moe_expert_parallel_world_size,
|
get_moe_expert_parallel_world_size,
|
||||||
|
get_pp_group,
|
||||||
get_tensor_model_parallel_world_size,
|
get_tensor_model_parallel_world_size,
|
||||||
parallel_state,
|
parallel_state,
|
||||||
tensor_model_parallel_all_reduce,
|
tensor_model_parallel_all_reduce,
|
||||||
@@ -87,13 +88,13 @@ from sglang.srt.layers.quantization.int8_utils import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
|
||||||
from sglang.srt.layers.utils import is_sm100_supported
|
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id, is_sm100_supported
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
ParallelLMHead,
|
ParallelLMHead,
|
||||||
VocabParallelEmbedding,
|
VocabParallelEmbedding,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.two_batch_overlap import (
|
from sglang.srt.two_batch_overlap import (
|
||||||
MaybeTboDeepEPDispatcher,
|
MaybeTboDeepEPDispatcher,
|
||||||
@@ -114,6 +115,7 @@ from sglang.srt.utils import (
|
|||||||
is_hip,
|
is_hip,
|
||||||
is_non_idle_and_non_empty,
|
is_non_idle_and_non_empty,
|
||||||
log_info_on_rank0,
|
log_info_on_rank0,
|
||||||
|
make_layers,
|
||||||
use_intel_amx_backend,
|
use_intel_amx_backend,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -2029,26 +2031,35 @@ class DeepseekV2Model(nn.Module):
|
|||||||
self.padding_id = config.pad_token_id
|
self.padding_id = config.pad_token_id
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.first_k_dense_replace = config.first_k_dense_replace
|
self.first_k_dense_replace = config.first_k_dense_replace
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
|
|
||||||
|
if self.pp_group.is_first_rank:
|
||||||
|
self.embed_tokens = VocabParallelEmbedding(
|
||||||
|
config.vocab_size,
|
||||||
|
config.hidden_size,
|
||||||
|
enable_tp=not is_dp_attention_enabled(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.embed_tokens = PPMissingLayer()
|
||||||
|
|
||||||
self.embed_tokens = VocabParallelEmbedding(
|
|
||||||
config.vocab_size,
|
|
||||||
config.hidden_size,
|
|
||||||
enable_tp=not is_dp_attention_enabled(),
|
|
||||||
)
|
|
||||||
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
|
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
|
||||||
self.layers = nn.ModuleList(
|
self.layers, self.start_layer, self.end_layer = make_layers(
|
||||||
[
|
config.num_hidden_layers,
|
||||||
DeepseekV2DecoderLayer(
|
lambda idx, prefix: DeepseekV2DecoderLayer(
|
||||||
config,
|
config=config,
|
||||||
layer_id,
|
layer_id=idx,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix(f"layers.{layer_id}", prefix),
|
prefix=prefix,
|
||||||
alt_stream=self.alt_stream,
|
alt_stream=self.alt_stream,
|
||||||
)
|
),
|
||||||
for layer_id in range(config.num_hidden_layers)
|
pp_rank=self.pp_group.rank_in_group,
|
||||||
]
|
pp_size=self.pp_group.world_size,
|
||||||
|
prefix=add_prefix("layers", prefix),
|
||||||
)
|
)
|
||||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
if self.pp_group.is_last_rank:
|
||||||
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
else:
|
||||||
|
self.norm = PPMissingLayer(return_tuple=True)
|
||||||
|
|
||||||
def get_input_embeddings(self) -> torch.Tensor:
|
def get_input_embeddings(self) -> torch.Tensor:
|
||||||
return self.embed_tokens
|
return self.embed_tokens
|
||||||
@@ -2059,8 +2070,9 @@ class DeepseekV2Model(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
) -> torch.Tensor:
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
total_num_layers = len(self.layers)
|
) -> Union[torch.Tensor, PPProxyTensors]:
|
||||||
|
total_num_layers = self.end_layer - self.start_layer
|
||||||
device = input_embeds.device if input_embeds is not None else input_ids.device
|
device = input_embeds.device if input_embeds is not None else input_ids.device
|
||||||
zero_allocator = BumpAllocator(
|
zero_allocator = BumpAllocator(
|
||||||
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
|
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
|
||||||
@@ -2068,44 +2080,62 @@ class DeepseekV2Model(nn.Module):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
if input_embeds is None:
|
if self.pp_group.is_first_rank:
|
||||||
hidden_states = self.embed_tokens(input_ids)
|
if input_embeds is None:
|
||||||
|
hidden_states = self.embed_tokens(input_ids)
|
||||||
|
else:
|
||||||
|
hidden_states = input_embeds
|
||||||
|
residual = None
|
||||||
else:
|
else:
|
||||||
hidden_states = input_embeds
|
assert pp_proxy_tensors is not None
|
||||||
|
hidden_states = pp_proxy_tensors["hidden_states"]
|
||||||
|
residual = pp_proxy_tensors["residual"]
|
||||||
|
|
||||||
residual = None
|
normal_start_layer = self.start_layer
|
||||||
|
normal_end_layer = self.end_layer
|
||||||
|
if forward_batch.can_run_tbo:
|
||||||
|
if (
|
||||||
|
self.first_k_dense_replace > normal_start_layer
|
||||||
|
and self.first_k_dense_replace < normal_end_layer
|
||||||
|
):
|
||||||
|
normal_end_layer = self.first_k_dense_replace
|
||||||
|
elif self.first_k_dense_replace < normal_start_layer:
|
||||||
|
normal_end_layer = normal_start_layer = 0
|
||||||
|
|
||||||
normal_num_layers = (
|
for i in range(normal_start_layer, normal_end_layer):
|
||||||
self.first_k_dense_replace
|
|
||||||
if forward_batch.can_run_tbo
|
|
||||||
else total_num_layers
|
|
||||||
)
|
|
||||||
for i in range(normal_num_layers):
|
|
||||||
with get_global_expert_distribution_recorder().with_current_layer(i):
|
with get_global_expert_distribution_recorder().with_current_layer(i):
|
||||||
layer = self.layers[i]
|
layer = self.layers[i]
|
||||||
hidden_states, residual = layer(
|
hidden_states, residual = layer(
|
||||||
positions, hidden_states, forward_batch, residual, zero_allocator
|
positions, hidden_states, forward_batch, residual, zero_allocator
|
||||||
)
|
)
|
||||||
|
|
||||||
if normal_num_layers != total_num_layers:
|
if normal_end_layer != self.end_layer:
|
||||||
hidden_states, residual = model_forward_maybe_tbo(
|
hidden_states, residual = model_forward_maybe_tbo(
|
||||||
layers=self.layers[normal_num_layers:],
|
layers=self.layers[normal_end_layer : self.end_layer],
|
||||||
enable_tbo=True,
|
enable_tbo=True,
|
||||||
positions=positions,
|
positions=positions,
|
||||||
forward_batch=forward_batch,
|
forward_batch=forward_batch,
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
residual=residual,
|
residual=residual,
|
||||||
input_data_scatter_mode=self.layers[
|
input_data_scatter_mode=self.layers[
|
||||||
normal_num_layers - 1
|
normal_end_layer - 1
|
||||||
].layer_scatter_modes.layer_output_mode,
|
].layer_scatter_modes.layer_output_mode,
|
||||||
zero_allocator=zero_allocator,
|
zero_allocator=zero_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not forward_batch.forward_mode.is_idle():
|
if not self.pp_group.is_last_rank:
|
||||||
if residual is None:
|
return PPProxyTensors(
|
||||||
hidden_states = self.norm(hidden_states)
|
{
|
||||||
else:
|
"hidden_states": hidden_states,
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
"residual": residual,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if not forward_batch.forward_mode.is_idle():
|
||||||
|
if residual is None:
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
else:
|
||||||
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -2132,6 +2162,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
"kv_a_proj_with_mqa",
|
"kv_a_proj_with_mqa",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self.pp_group = get_pp_group()
|
||||||
self.config = config
|
self.config = config
|
||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.quant_config = quant_config
|
self.quant_config = quant_config
|
||||||
@@ -2201,13 +2232,27 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
positions: torch.Tensor,
|
positions: torch.Tensor,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
input_embeds: torch.Tensor = None,
|
input_embeds: torch.Tensor = None,
|
||||||
|
pp_proxy_tensors: Optional[PPProxyTensors] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
hidden_states = self.model(
|
||||||
|
input_ids, positions, forward_batch, input_embeds, pp_proxy_tensors
|
||||||
return self.logits_processor(
|
|
||||||
input_ids, hidden_states, self.lm_head, forward_batch
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.pp_group.is_last_rank:
|
||||||
|
return self.logits_processor(
|
||||||
|
input_ids, hidden_states, self.lm_head, forward_batch
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
@property
def start_layer(self):
    """Expose the wrapped model's ``start_layer`` on the causal-LM wrapper.

    The value is read straight from ``self.model``; nothing is cached here.
    """
    inner_model = self.model
    return inner_model.start_layer
|
||||||
|
|
||||||
|
@property
def end_layer(self):
    """Expose the wrapped model's ``end_layer`` on the causal-LM wrapper.

    The value is read straight from ``self.model``; nothing is cached here.
    """
    inner_model = self.model
    return inner_model.end_layer
|
||||||
|
|
||||||
def post_load_weights(self, is_nextn=False, weight_names=None):
|
def post_load_weights(self, is_nextn=False, weight_names=None):
|
||||||
|
|
||||||
# Perform post-processing after loading weights
|
# Perform post-processing after loading weights
|
||||||
@@ -2215,7 +2260,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
layer_ids = [self.config.num_hidden_layers]
|
layer_ids = [self.config.num_hidden_layers]
|
||||||
else:
|
else:
|
||||||
if weight_names is None:
|
if weight_names is None:
|
||||||
layer_ids = range(self.config.num_hidden_layers)
|
layer_ids = range(self.model.start_layer, self.model.end_layer)
|
||||||
else:
|
else:
|
||||||
layer_ids = set()
|
layer_ids = set()
|
||||||
for name in weight_names:
|
for name in weight_names:
|
||||||
@@ -2497,6 +2542,16 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
weight_names = []
|
weight_names = []
|
||||||
for name, loaded_weight in weights:
|
for name, loaded_weight in weights:
|
||||||
|
layer_id = get_layer_id(name)
|
||||||
|
if (
|
||||||
|
layer_id is not None
|
||||||
|
and hasattr(self.model, "start_layer")
|
||||||
|
and (
|
||||||
|
layer_id < self.model.start_layer
|
||||||
|
or layer_id >= self.model.end_layer
|
||||||
|
)
|
||||||
|
):
|
||||||
|
continue
|
||||||
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
|
if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
|
||||||
name = name.replace(
|
name = name.replace(
|
||||||
"mlp.shared_experts",
|
"mlp.shared_experts",
|
||||||
@@ -2581,6 +2636,12 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|||||||
# Skip loading extra bias for GPTQ models.
|
# Skip loading extra bias for GPTQ models.
|
||||||
if name.endswith(".bias") and name not in params_dict:
|
if name.endswith(".bias") and name not in params_dict:
|
||||||
continue
|
continue
|
||||||
|
# Skip loading embed_tokens if not first rank in pipeline parallelism
|
||||||
|
if ".embed_tokens." in name and not self.pp_group.is_first_rank:
|
||||||
|
continue
|
||||||
|
# Skip loading norm if not last rank in pipeline parallelism
|
||||||
|
if ".norm." in name and not self.pp_group.is_last_rank:
|
||||||
|
continue
|
||||||
if fuse_qkv_a_proj and (
|
if fuse_qkv_a_proj and (
|
||||||
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
||||||
):
|
):
|
||||||
|
|||||||
133
test/srt/test_disaggregation_pp.py
Normal file
133
test/srt/test_disaggregation_pp.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
import unittest
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.few_shot_gsm8k import run_eval
|
||||||
|
from sglang.test.runners import DEFAULT_PROMPTS
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
DEFAULT_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
CustomTestCase,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPDPPAccuracy(unittest.TestCase):
    """GSM8K accuracy check for PD disaggregation with a pipeline-parallel prefill server.

    ``setUpClass`` launches three processes: a prefill server (``--tp-size 2
    --pp-size 2``), a decode server (``--tp 1``) and the ``mini_lb`` load
    balancer in front of them. The test then runs the few-shot GSM8K eval
    through the load balancer.
    """

    @classmethod
    def setUpClass(cls):
        """Start the prefill/decode servers and the load balancer, waiting for each to be healthy."""
        # Imported locally because the module's import header does not provide
        # these names.
        import subprocess
        from urllib.parse import urlparse

        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        parsed_url = urlparse(DEFAULT_URL_FOR_TEST)
        cls.base_host = parsed_url.hostname
        base_port = str(parsed_url.port)
        cls.lb_port = base_port
        cls.prefill_port = f"{int(base_port) + 100}"
        cls.decode_port = f"{int(base_port) + 200}"
        cls.prefill_url = f"http://{cls.base_host}:{cls.prefill_port}"
        cls.decode_url = f"http://{cls.base_host}:{cls.decode_port}"
        cls.lb_url = f"http://{cls.base_host}:{cls.lb_port}"
        print(f"{cls.base_host=} {cls.lb_port=} {cls.prefill_port=} {cls.decode_port=}")

        # Pre-declare the handles so tearDownClass can run after a partial start.
        cls.process_prefill = None
        cls.process_decode = None
        cls.process_lb = None

        try:
            # Non-blocking server start.
            cls.start_prefill()
            cls.start_decode()

            # Block until both servers report healthy.
            cls.wait_server_ready(cls.prefill_url + "/health")
            cls.wait_server_ready(cls.decode_url + "/health")

            lb_command = [
                "python3",
                "-m",
                "sglang.srt.disaggregation.mini_lb",
                "--prefill",
                cls.prefill_url,
                "--decode",
                cls.decode_url,
                "--host",
                cls.base_host,
                "--port",
                cls.lb_port,
            ]

            print("Starting load balancer:", " ".join(lb_command))
            # The output is never consumed, so discard it instead of piping it:
            # an unread PIPE can fill up and block the child process.
            cls.process_lb = subprocess.Popen(
                lb_command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
            )
            cls.wait_server_ready(cls.lb_url + "/health")
        except Exception:
            # unittest does not call tearDownClass when setUpClass raises, so
            # clean up whatever was started before propagating the failure.
            cls.tearDownClass()
            raise

    @classmethod
    def start_prefill(cls):
        """Launch the prefill server with TP=2 and PP=2 (does not wait for readiness)."""
        # Assumed to live next to ``popen_launch_server`` in the test utils;
        # the module header does not import it.
        from sglang.test.test_utils import popen_launch_pd_server

        prefill_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "prefill",
            "--tp-size",
            "2",
            "--pp-size",
            "2",
            # NOTE(review): the IB device name is specific to the CI host.
            "--disaggregation-ib-device",
            "mlx5_roce0",
            "--disable-overlap-schedule",
        ]
        cls.process_prefill = popen_launch_pd_server(
            cls.model,
            cls.prefill_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=prefill_args,
        )

    @classmethod
    def start_decode(cls):
        """Launch the decode server with TP=1 (does not wait for readiness)."""
        # Assumed to live next to ``popen_launch_server`` in the test utils;
        # the module header does not import it.
        from sglang.test.test_utils import popen_launch_pd_server

        decode_args = [
            "--trust-remote-code",
            "--disaggregation-mode",
            "decode",
            "--tp",
            "1",
            # NOTE(review): prefill runs tp=2 x pp=2; if that occupies GPUs 0-3,
            # base GPU 1 overlaps with it. Confirm the intended device layout.
            "--base-gpu-id",
            "1",
            "--disaggregation-ib-device",
            "mlx5_roce1",
        ]
        cls.process_decode = popen_launch_pd_server(
            cls.model,
            cls.decode_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=decode_args,
        )

    @classmethod
    def wait_server_ready(cls, url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH):
        """Poll ``url`` once per second until it answers HTTP 200.

        Raises:
            RuntimeError: if the endpoint is not healthy within ``timeout`` seconds.
        """
        deadline = time.perf_counter() + timeout
        while True:
            try:
                if requests.get(url, timeout=5).status_code == 200:
                    print(f"Server {url} is ready")
                    return
            except requests.exceptions.RequestException:
                # The server is not accepting connections yet; keep polling.
                pass

            if time.perf_counter() > deadline:
                raise RuntimeError(f"Server {url} failed to start in {timeout}s")
            time.sleep(1)

    @classmethod
    def tearDownClass(cls):
        """Kill every process started by ``setUpClass`` (best effort)."""
        for attr in ("process_lb", "process_decode", "process_prefill"):
            process = getattr(cls, attr, None)
            if process is None:
                continue
            try:
                kill_process_tree(process.pid)
            except Exception as exc:
                # Keep going so one failed kill does not leak the other servers.
                print(f"Error killing process {process.pid}: {exc}")

    def test_gsm8k(self):
        """Run 200 five-shot GSM8K questions through the load balancer and check accuracy."""
        args = SimpleNamespace(
            num_shots=5,
            data_path=None,
            num_questions=200,
            max_new_tokens=512,
            parallel=128,
            # Requests go through the load balancer started in setUpClass.
            host=f"http://{self.base_host}",
            port=int(self.lb_port),
        )
        metrics = run_eval(args)
        print(f"{metrics=}")

        self.assertGreater(metrics["accuracy"], 0.24)
        # Wait a little bit so that the memory check happens.
        time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -9,6 +9,8 @@ import time
|
|||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
|
from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
@@ -62,6 +64,29 @@ class TestPPAccuracy(unittest.TestCase):
|
|||||||
# Wait a little bit so that the memory check happens.
|
# Wait a little bit so that the memory check happens.
|
||||||
time.sleep(4)
|
time.sleep(4)
|
||||||
|
|
||||||
|
def test_logprob(self):
    """Check that ``/generate`` returns logprob lists of the expected lengths.

    Sends one greedy request with ``return_logprob`` enabled and verifies the
    number of input-token, output-token and output-top logprob entries.
    """
    max_new_tokens = 16
    response = requests.post(
        f"{self.base_url}/generate",
        json={
            "text": "The capital of France is",
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": max_new_tokens,
            },
            "return_logprob": True,
            "top_logprobs_num": 5,
            "logprob_start_len": 0,
        },
    )
    # Fail with the server's message instead of a KeyError on the JSON below.
    self.assertEqual(response.status_code, 200, response.text)

    meta_info = response.json()["meta_info"]
    input_token_logprobs = meta_info["input_token_logprobs"]
    output_token_logprobs = meta_info["output_token_logprobs"]
    output_top_logprobs = meta_info["output_top_logprobs"]

    # unittest assertions are not stripped under ``python -O`` and report
    # both values on failure.
    # NOTE(review): 6 is presumably the prompt's token count for the test
    # model's tokenizer; confirm if the model changes.
    self.assertEqual(len(input_token_logprobs), 6)
    self.assertEqual(len(output_token_logprobs), max_new_tokens)
    self.assertEqual(len(output_top_logprobs), max_new_tokens)
|
||||||
|
|
||||||
|
|
||||||
class TestQwenPPAccuracy(unittest.TestCase):
|
class TestQwenPPAccuracy(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
Reference in New Issue
Block a user