From 11383cec3c08e7912c4398838e33eafe529e1732 Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Wed, 30 Apr 2025 18:18:07 -0700 Subject: [PATCH] [PP] Add pipeline parallelism (#5724) --- python/sglang/bench_one_batch.py | 2 + python/sglang/srt/entrypoints/engine.py | 55 +-- python/sglang/srt/layers/dp_attention.py | 7 +- python/sglang/srt/layers/utils.py | 35 ++ .../srt/managers/data_parallel_controller.py | 84 +++-- python/sglang/srt/managers/schedule_batch.py | 46 ++- python/sglang/srt/managers/scheduler.py | 317 ++++++++++++++---- .../scheduler_output_processor_mixin.py | 2 +- python/sglang/srt/managers/tp_worker.py | 66 +++- .../srt/managers/tp_worker_overlap_thread.py | 12 +- python/sglang/srt/mem_cache/memory_pool.py | 106 ++++-- .../srt/model_executor/cuda_graph_runner.py | 85 ++++- .../srt/model_executor/forward_batch_info.py | 32 +- .../sglang/srt/model_executor/model_runner.py | 157 ++++++--- python/sglang/srt/models/llama.py | 122 +++++-- python/sglang/srt/models/llama4.py | 3 +- python/sglang/srt/models/llama_eagle.py | 5 +- python/sglang/srt/models/llama_eagle3.py | 5 +- python/sglang/srt/server_args.py | 53 ++- python/sglang/srt/speculative/eagle_worker.py | 5 +- python/sglang/srt/utils.py | 84 ++++- python/sglang/test/test_utils.py | 28 ++ test/srt/run_suite.py | 2 + test/srt/test_pp_single_node.py | 143 ++++++++ test/srt/test_vlm_accuracy.py | 2 + 25 files changed, 1150 insertions(+), 308 deletions(-) create mode 100644 python/sglang/srt/layers/utils.py create mode 100644 test/srt/test_pp_single_node.py diff --git a/python/sglang/bench_one_batch.py b/python/sglang/bench_one_batch.py index cdf8e9ea3..e70f3af2d 100644 --- a/python/sglang/bench_one_batch.py +++ b/python/sglang/bench_one_batch.py @@ -154,6 +154,8 @@ def load_model(server_args, port_args, tp_rank): gpu_id=tp_rank, tp_rank=tp_rank, tp_size=server_args.tp_size, + pp_rank=0, + pp_size=1, nccl_port=port_args.nccl_port, server_args=server_args, ) diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py index 3c1d308e8..444c64771 100644 --- a/python/sglang/srt/entrypoints/engine.py +++ b/python/sglang/srt/entrypoints/engine.py @@ -126,7 +126,6 @@ class Engine(EngineBase): server_args=server_args, port_args=port_args, ) - self.server_args = server_args self.tokenizer_manager = tokenizer_manager self.scheduler_info = scheduler_info @@ -301,7 +300,6 @@ class Engine(EngineBase): internal_states = loop.run_until_complete( self.tokenizer_manager.get_internal_state() ) - return { **dataclasses.asdict(self.tokenizer_manager.server_args), **self.scheduler_info, @@ -520,25 +518,44 @@ def _launch_subprocesses( ) scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes + + nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) + tp_size_per_node = server_args.tp_size // nnodes_per_tp_group tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), ) - for tp_rank in tp_rank_range: - reader, writer = mp.Pipe(duplex=False) - gpu_id = ( - server_args.base_gpu_id - + (tp_rank % tp_size_per_node) * server_args.gpu_id_step - ) - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, port_args, gpu_id, tp_rank, None, writer), - ) - with memory_saver_adapter.configure_subprocess(): - proc.start() - scheduler_procs.append(proc) - 
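The nested `pp_rank_range`/`tp_rank_range` loops above place one PP stage's TP ranks on each node slice and derive a GPU id from the `(pp_rank, tp_rank)` pair. A minimal standalone sketch of that arithmetic (`local_ranks_and_gpus` is a hypothetical helper, and all sizes are illustrative, not part of the patch):

```python
# Standalone sketch of the rank/GPU-id arithmetic used when launching
# scheduler subprocesses. All concrete values are illustrative assumptions.
def local_ranks_and_gpus(tp_size, pp_size, nnodes, node_rank, base_gpu_id=0, gpu_id_step=1):
    # Nodes are grouped so that each TP group spans `nnodes_per_tp_group` nodes.
    nnodes_per_tp_group = max(nnodes // pp_size, 1)
    tp_size_per_node = tp_size // nnodes_per_tp_group
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )

    pp_size_per_node = max(pp_size // nnodes, 1)
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )

    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = (
                base_gpu_id
                + (pp_rank % pp_size_per_node) * tp_size_per_node
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            yield pp_rank, tp_rank, gpu_id

# Single node, tp=2, pp=2: four schedulers on GPUs 0..3.
print(list(local_ranks_and_gpus(tp_size=2, pp_size=2, nnodes=1, node_rank=0)))
# [(0, 0, 0), (0, 1, 1), (1, 0, 2), (1, 1, 3)]
```

With `gpu_id_step=1`, each stage's TP workers occupy a contiguous block of GPUs on the node, and consecutive stages occupy consecutive blocks.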
scheduler_pipe_readers.append(reader) + + pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) + pp_rank_range = range( + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), + ) + + for pp_rank in pp_rank_range: + for tp_rank in tp_rank_range: + reader, writer = mp.Pipe(duplex=False) + gpu_id = ( + server_args.base_gpu_id + + ((pp_rank % pp_size_per_node) * tp_size_per_node) + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step + ) + proc = mp.Process( + target=run_scheduler_process, + args=( + server_args, + port_args, + gpu_id, + tp_rank, + pp_rank, + None, + writer, + ), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) else: # Launch the data parallel controller reader, writer = mp.Pipe(duplex=False) diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py index 5f140a3df..69f94407c 100644 --- a/python/sglang/srt/layers/dp_attention.py +++ b/python/sglang/srt/layers/dp_attention.py @@ -43,6 +43,7 @@ def initialize_dp_attention( tp_rank: int, tp_size: int, dp_size: int, + pp_size: int, ): global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE @@ -53,17 +54,19 @@ def initialize_dp_attention( ) if enable_dp_attention: + local_rank = tp_rank % (tp_size // dp_size) _DP_SIZE = dp_size else: + local_rank = tp_rank _DP_SIZE = 1 tp_group = get_tp_group() _ATTN_TP_GROUP = GroupCoordinator( [ list(range(head, head + _ATTN_TP_SIZE)) - for head in range(0, tp_size, _ATTN_TP_SIZE) + for head in range(0, pp_size * tp_size, _ATTN_TP_SIZE) ], - tp_group.local_rank, + local_rank, torch.distributed.get_backend(tp_group.device_group), SYNC_TOKEN_IDS_ACROSS_TP, False, diff --git a/python/sglang/srt/layers/utils.py b/python/sglang/srt/layers/utils.py new file mode 100644 index 000000000..63e775618 --- /dev/null +++ b/python/sglang/srt/layers/utils.py @@ -0,0 +1,35 @@ +import logging +import re + +import torch + +logger = logging.getLogger(__name__) + + +def get_layer_id(weight_name): + # example weight name: model.layers.10.self_attn.qkv_proj.weight + match = re.search(r"layers\.(\d+)\.", weight_name) + if match: + return int(match.group(1)) + return None + + +class PPMissingLayer(torch.nn.Identity): + # Adapted from + # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1 + """ + A placeholder layer for missing layers in a pipeline parallel model. + """ + + def __init__(self, *args, **kwargs): + super().__init__() + self.return_tuple = kwargs.get("return_tuple", False) + + def forward(self, *args, **kwargs): + """ + Return the first arg from args or the first value from kwargs. + + Wraps the input in a tuple if `self.return_tuple` is True. 
+ """ + input = args[0] if args else next(iter(kwargs.values())) + return (input,) if self.return_tuple else input diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py index ce921988b..e0e35a1ca 100644 --- a/python/sglang/srt/managers/data_parallel_controller.py +++ b/python/sglang/srt/managers/data_parallel_controller.py @@ -181,44 +181,62 @@ class DataParallelController: enable=server_args.enable_memory_saver ) - # Launch tensor parallel scheduler processes scheduler_pipe_readers = [] - tp_size_per_node = server_args.tp_size // server_args.nnodes + + nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1) + tp_size_per_node = server_args.tp_size // nnodes_per_tp_group tp_rank_range = range( - tp_size_per_node * server_args.node_rank, - tp_size_per_node * (server_args.node_rank + 1), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group), + tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1), ) - for tp_rank in tp_rank_range: - rank_port_args = port_args - if server_args.enable_dp_attention: - # dp attention has different sharding logic - _, _, dp_rank = compute_dp_attention_world_info( - server_args.enable_dp_attention, - tp_rank, - server_args.tp_size, - server_args.dp_size, + pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1) + pp_rank_range = range( + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group), + pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1), + ) + + for pp_rank in pp_rank_range: + for tp_rank in tp_rank_range: + rank_port_args = port_args + + if server_args.enable_dp_attention: + # dp attention has different sharding logic + _, _, dp_rank = compute_dp_attention_world_info( + server_args.enable_dp_attention, + tp_rank, + server_args.tp_size, + server_args.dp_size, + ) + # compute zmq ports for this dp rank + rank_port_args = PortArgs.init_new(server_args, dp_rank) + # Data parallelism resues the tensor parallelism group, + # so all dp ranks should use the same nccl port. + rank_port_args.nccl_port = port_args.nccl_port + + reader, writer = mp.Pipe(duplex=False) + gpu_id = ( + server_args.base_gpu_id + + base_gpu_id + + ((pp_rank % pp_size_per_node) * tp_size_per_node) + + (tp_rank % tp_size_per_node) * server_args.gpu_id_step ) - # compute zmq ports for this dp rank - rank_port_args = PortArgs.init_new(server_args, dp_rank) - # Data parallelism resues the tensor parallelism group, - # so all dp ranks should use the same nccl port. 
- rank_port_args.nccl_port = port_args.nccl_port - - reader, writer = mp.Pipe(duplex=False) - gpu_id = ( - server_args.base_gpu_id - + base_gpu_id - + (tp_rank % tp_size_per_node) * server_args.gpu_id_step - ) - proc = mp.Process( - target=run_scheduler_process, - args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer), - ) - with memory_saver_adapter.configure_subprocess(): - proc.start() - self.scheduler_procs.append(proc) - scheduler_pipe_readers.append(reader) + proc = mp.Process( + target=run_scheduler_process, + args=( + server_args, + rank_port_args, + gpu_id, + tp_rank, + pp_rank, + dp_rank, + writer, + ), + ) + with memory_saver_adapter.configure_subprocess(): + proc.start() + self.scheduler_procs.append(proc) + scheduler_pipe_readers.append(reader) # Wait for model to finish loading scheduler_info = [] diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 314dbbd2e..58d5637dd 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -66,23 +66,24 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 # Put some global args for easy access global_server_args_dict = { "attention_backend": ServerArgs.attention_backend, - "sampling_backend": ServerArgs.sampling_backend, - "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, - "torchao_config": ServerArgs.torchao_config, - "enable_nan_detection": ServerArgs.enable_nan_detection, - "enable_dp_attention": ServerArgs.enable_dp_attention, - "enable_ep_moe": ServerArgs.enable_ep_moe, - "enable_deepep_moe": ServerArgs.enable_deepep_moe, + "chunked_prefill_size": ServerArgs.chunked_prefill_size, "deepep_mode": ServerArgs.deepep_mode, "device": ServerArgs.device, - "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, - "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, - "disable_radix_cache": ServerArgs.disable_radix_cache, - "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, - "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, - "chunked_prefill_size": ServerArgs.chunked_prefill_size, - "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache, + "disable_radix_cache": ServerArgs.disable_radix_cache, + "enable_deepep_moe": ServerArgs.enable_deepep_moe, + "enable_dp_attention": ServerArgs.enable_dp_attention, + "enable_ep_moe": ServerArgs.enable_ep_moe, + "enable_nan_detection": ServerArgs.enable_nan_detection, + "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, + "max_micro_batch_size": ServerArgs.max_micro_batch_size, + "moe_dense_tp_size": ServerArgs.moe_dense_tp_size, + "n_share_experts_fusion": ServerArgs.n_share_experts_fusion, + "sampling_backend": ServerArgs.sampling_backend, + "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, + "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, + "torchao_config": ServerArgs.torchao_config, + "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32, } logger = logging.getLogger(__name__) @@ -728,6 +729,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Events launch_done: Optional[threading.Event] = None + # For chunked prefill in PP + chunked_req: Optional[Req] = None + # Sampling info sampling_info: SamplingBatchInfo = None next_batch_sampling_info: SamplingBatchInfo = 
None @@ -761,7 +765,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # For extend and mixed chunekd prefill prefix_lens: List[int] = None extend_lens: List[int] = None - extend_num_tokens: int = None + extend_num_tokens: Optional[int] = None decoding_reqs: List[Req] = None extend_logprob_start_lens: List[int] = None # It comes empty list if logprob is not required. @@ -803,6 +807,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): enable_overlap: bool, spec_algorithm: SpeculativeAlgorithm, enable_custom_logit_processor: bool, + chunked_req: Optional[Req] = None, ): return_logprob = any(req.return_logprob for req in reqs) @@ -820,6 +825,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): spec_algorithm=spec_algorithm, enable_custom_logit_processor=enable_custom_logit_processor, return_hidden_states=any(req.return_hidden_states for req in reqs), + chunked_req=chunked_req, ) def batch_size(self): @@ -1236,7 +1242,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def retract_decode(self, server_args: ServerArgs): """Retract the decoding requests when there is not enough memory.""" - sorted_indices = [i for i in range(len(self.reqs))] + sorted_indices = list(range(len(self.reqs))) # TODO(lsyin): improve retraction policy for radix cache # For spec decoding, filter_batch API can only filter @@ -1413,15 +1419,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): def filter_batch( self, - chunked_req_to_exclude: Optional[Req] = None, + chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None, keep_indices: Optional[List[int]] = None, ): if keep_indices is None: + if isinstance(chunked_req_to_exclude, Req): + chunked_req_to_exclude = [chunked_req_to_exclude] + elif chunked_req_to_exclude is None: + chunked_req_to_exclude = [] keep_indices = [ i for i in range(len(self.reqs)) if not self.reqs[i].finished() - and self.reqs[i] is not chunked_req_to_exclude + and not self.reqs[i] in chunked_req_to_exclude ] if keep_indices is None or len(keep_indices) == 0: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ffcbd4667..8891115c1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -51,6 +51,7 @@ from sglang.srt.disaggregation.utils import ( ReqToMetadataIdxAllocator, TransferBackend, ) +from sglang.srt.distributed import get_pp_group, get_world_group from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.dp_attention import compute_dp_attention_world_info from sglang.srt.layers.logits_processor import LogitsProcessorOutput @@ -114,7 +115,11 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache from sglang.srt.mem_cache.hiradix_cache import HiRadixCache from sglang.srt.mem_cache.radix_cache import RadixCache from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats -from sglang.srt.model_executor.forward_batch_info import ForwardMode +from sglang.srt.model_executor.forward_batch_info import ( + ForwardBatch, + ForwardMode, + PPProxyTensors, +) from sglang.srt.reasoning_parser import ReasoningParser from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm @@ -127,6 +132,7 @@ from sglang.srt.utils import ( get_bool_env_var, get_zmq_socket, kill_itself_when_parent_died, + point_to_point_pyobj, pyspy_dump_schedulers, set_gpu_proc_affinity, set_random_seed, @@ -145,8 +151,9 @@ RECORD_STEP_TIME = 
get_bool_env_var("SGLANG_RECORD_STEP_TIME") @dataclass class GenerationBatchResult: - logits_output: LogitsProcessorOutput - next_token_ids: List[int] + logits_output: Optional[LogitsProcessorOutput] + pp_hidden_states_proxy_tensors: Optional[torch.Tensor] + next_token_ids: Optional[List[int]] extend_input_len_per_req: List[int] extend_logprob_start_len_per_req: List[int] bid: int @@ -171,12 +178,16 @@ class Scheduler( port_args: PortArgs, gpu_id: int, tp_rank: int, + pp_rank: int, dp_rank: Optional[int], ): # Parse args self.server_args = server_args self.tp_rank = tp_rank + self.pp_rank = pp_rank self.tp_size = server_args.tp_size + self.pp_size = server_args.pp_size + self.dp_size = server_args.dp_size self.schedule_policy = server_args.schedule_policy self.lora_paths = server_args.lora_paths self.max_loras_per_batch = server_args.max_loras_per_batch @@ -192,7 +203,6 @@ class Scheduler( self.page_size = server_args.page_size # Distributed rank info - self.dp_size = server_args.dp_size self.attn_tp_rank, self.attn_tp_size, self.dp_rank = ( compute_dp_attention_world_info( server_args.enable_dp_attention, @@ -204,7 +214,7 @@ class Scheduler( # Init inter-process communication context = zmq.Context(2) - if self.attn_tp_rank == 0: + if self.pp_rank == 0 and self.attn_tp_rank == 0: self.recv_from_tokenizer = get_zmq_socket( context, zmq.PULL, port_args.scheduler_input_ipc_name, False ) @@ -259,6 +269,7 @@ class Scheduler( server_args=server_args, gpu_id=gpu_id, tp_rank=tp_rank, + pp_rank=pp_rank, dp_rank=dp_rank, nccl_port=port_args.nccl_port, ) @@ -292,8 +303,18 @@ class Scheduler( _, _, ) = self.tp_worker.get_worker_info() - self.tp_cpu_group = self.tp_worker.get_tp_cpu_group() + if global_server_args_dict["max_micro_batch_size"] is None: + global_server_args_dict["max_micro_batch_size"] = max( + self.max_running_requests // server_args.pp_size, 1 + ) + + self.tp_group = self.tp_worker.get_tp_group() + self.tp_cpu_group = self.tp_group.cpu_group + self.attn_tp_group = self.tp_worker.get_attention_tp_group() self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group() + self.pp_group = get_pp_group() + self.world_group = get_world_group() + self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func() global_server_args_dict.update(worker_global_server_args_dict) set_random_seed(self.random_seed) @@ -673,26 +694,141 @@ class Scheduler( self.last_batch = batch + @DynamicGradMode() + def event_loop_pp(self): + """A non-overlap scheduler loop for pipeline parallelism.""" + mbs = [None] * self.pp_size + last_mbs = [None] * self.pp_size + self.running_mbs = [ + ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size) + ] + bids = [None] * self.pp_size + pp_outputs: Optional[PPProxyTensors] = None + while True: + server_is_idle = True + for mb_id in range(self.pp_size): + self.running_batch = self.running_mbs[mb_id] + self.last_batch = last_mbs[mb_id] + + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) + mbs[mb_id] = self.get_next_batch_to_run() + self.running_mbs[mb_id] = self.running_batch + + self.cur_batch = mbs[mb_id] + if self.cur_batch: + server_is_idle = False + result = self.run_batch(self.cur_batch) + + # send the outputs to the next step + if self.pp_group.is_last_rank: + if self.cur_batch: + next_token_ids, bids[mb_id] = ( + result.next_token_ids, + result.bid, + ) + pp_outputs = PPProxyTensors( + { + "next_token_ids": next_token_ids, + } + ) + # send the output from the last round to let the next stage worker run post processing + 
self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + # receive outputs and post-process (filter finished reqs) the coming microbatch + next_mb_id = (mb_id + 1) % self.pp_size + next_pp_outputs = None + if mbs[next_mb_id] is not None: + next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors( + self.pp_group.recv_tensor_dict( + all_gather_group=self.attn_tp_group + ) + ) + mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"] + output_result = GenerationBatchResult( + logits_output=None, + pp_hidden_states_proxy_tensors=None, + next_token_ids=next_pp_outputs["next_token_ids"], + extend_input_len_per_req=None, + extend_logprob_start_len_per_req=None, + bid=bids[next_mb_id], + ) + self.process_batch_result(mbs[next_mb_id], output_result) + last_mbs[next_mb_id] = mbs[next_mb_id] + + # carry the outputs to the next stage + if not self.pp_group.is_last_rank: + if self.cur_batch: + bids[mb_id] = result.bid + if pp_outputs: + # send the outputs from the last round to let the next stage worker run post processing + self.pp_group.send_tensor_dict( + pp_outputs.tensors, + all_gather_group=self.attn_tp_group, + ) + + if not self.pp_group.is_last_rank: + # send out reqs to the next stage + dp_offset = self.dp_rank * self.attn_tp_size + if self.attn_tp_rank == 0: + point_to_point_pyobj( + recv_reqs, + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + self.pp_rank * self.tp_size + dp_offset, + (self.pp_rank + 1) * self.tp_size + dp_offset, + ) + + # send out proxy tensors to the next stage + if self.cur_batch: + self.pp_group.send_tensor_dict( + result.pp_hidden_states_proxy_tensors, + all_gather_group=self.attn_tp_group, + ) + + pp_outputs = next_pp_outputs + + # When the server is idle, self-check and re-init some states + if server_is_idle: + self.check_memory() + self.new_token_ratio = self.init_new_token_ratio + def recv_requests(self) -> List[Req]: """Receive results at tp_rank = 0 and broadcast it to all other TP ranks.""" - if self.attn_tp_rank == 0: - recv_reqs = [] + if self.pp_rank == 0: + if self.attn_tp_rank == 0: + recv_reqs = [] - while True: - try: - recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break - recv_reqs.append(recv_req) + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_req) - while True: - try: - recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break - recv_reqs.append(recv_rpc) + while True: + try: + recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_rpc) + else: + recv_reqs = None else: - recv_reqs = None + if self.attn_tp_rank == 0: + dp_offset = self.dp_rank * self.attn_tp_size + recv_reqs = point_to_point_pyobj( + [], + self.pp_rank * self.tp_size + dp_offset, + self.world_group.cpu_group, + (self.pp_rank - 1) * self.tp_size + dp_offset, + self.pp_rank * self.tp_size + dp_offset, + ) + else: + recv_reqs = None if self.server_args.enable_dp_attention: if self.attn_tp_rank == 0: @@ -715,20 +851,27 @@ class Scheduler( control_reqs = None if self.attn_tp_size != 1: - attn_tp_rank_0 = self.dp_rank * self.attn_tp_size work_reqs = broadcast_pyobj( work_reqs, - self.attn_tp_rank, + self.attn_tp_group.rank, self.attn_tp_cpu_group, - src=attn_tp_rank_0, + src=self.attn_tp_group.ranks[0], ) if self.tp_size != 1: control_reqs = broadcast_pyobj( - control_reqs, self.tp_rank, self.tp_cpu_group + 
control_reqs, + self.tp_group.rank, + self.tp_cpu_group, + src=self.tp_group.ranks[0], ) recv_reqs = work_reqs + control_reqs elif self.tp_size != 1: - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) + recv_reqs = broadcast_pyobj( + recv_reqs, + self.tp_group.rank, + self.tp_cpu_group, + src=self.tp_group.ranks[0], + ) return recv_reqs def process_input_requests(self, recv_reqs: List): @@ -1026,12 +1169,14 @@ class Scheduler( self.metrics_collector.log_stats(self.stats) - def log_decode_stats(self): + def log_decode_stats(self, running_batch=None): + batch = running_batch or self.running_batch + gap_latency = time.time() - self.last_decode_stats_tic self.last_decode_stats_tic = time.time() self.last_gen_throughput = self.num_generated_tokens / gap_latency self.num_generated_tokens = 0 - num_running_reqs = len(self.running_batch.reqs) + num_running_reqs = len(batch.reqs) num_used = self.max_total_num_tokens - ( self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size() ) @@ -1131,19 +1276,25 @@ class Scheduler( def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: # Merge the prefill batch into the running batch + chunked_req_to_exclude = set() + if self.chunked_req: + # Move the chunked request out of the batch so that we can merge + # only finished requests to running_batch. + chunked_req_to_exclude.add(self.chunked_req) + self.tree_cache.cache_unfinished_req(self.chunked_req) + # chunked request keeps its rid but will get a new req_pool_idx + self.req_to_token_pool.free(self.chunked_req.req_pool_idx) if self.last_batch and self.last_batch.forward_mode.is_extend(): - if self.chunked_req: - # Move the chunked request out of the batch so that we can merge - # only finished requests to running_batch. - self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req) - self.tree_cache.cache_unfinished_req(self.chunked_req) - # chunked request keeps its rid but will get a new req_pool_idx - self.req_to_token_pool.free(self.chunked_req.req_pool_idx) - self.running_batch.batch_is_full = False + if self.last_batch.chunked_req is not None: + # In the context of pipeline parallelism, after the last chunk, the current microbatch still tracks the outdated chunked_req. + # We need to discard it. + chunked_req_to_exclude.add(self.last_batch.chunked_req) # Filter batch last_bs = self.last_batch.batch_size() - self.last_batch.filter_batch() + self.last_batch.filter_batch( + chunked_req_to_exclude=list(chunked_req_to_exclude) + ) if self.last_batch.batch_size() < last_bs: self.running_batch.batch_is_full = False @@ -1173,6 +1324,12 @@ class Scheduler( return ret + def get_num_allocatable_reqs(self, running_bs): + res = global_server_args_dict["max_micro_batch_size"] - running_bs + if self.pp_size > 1: + res = min(res, self.req_to_token_pool.available_size()) + return res + def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Check if the grammar is ready in the grammar queue if self.grammar_queue: @@ -1185,7 +1342,12 @@ class Scheduler( return None running_bs = len(self.running_batch.reqs) - if running_bs >= self.max_running_requests: + # Ignore the check if self.chunked_req is not None. + # In the non-PP case, when self.chunked_req is not None, num_allocatable_reqs should always be greater than 0, + # as the space for the chunked request has just been released. + # In the PP case, a chunked req can start in one microbatch and end in another microbatch, so the max_running_requests per microbatch should not be strict.
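For reference, the admission rule introduced here is small enough to run standalone. `max_micro_batch_size` defaults to `max(max_running_requests // pp_size, 1)` (set in the scheduler's init above), and under PP the count is additionally capped by free `req_to_token_pool` slots, since other microbatches hold entries in the same pool. A small numeric sketch with made-up sizes:

```python
# Hedged sketch of the per-microbatch admission math from this diff.
# `max_running_requests`, `pp_size`, and the pool size are made-up numbers.
max_running_requests = 32
pp_size = 4
max_micro_batch_size = max(max_running_requests // pp_size, 1)  # default: 8

def get_num_allocatable_reqs(running_bs, free_req_slots):
    res = max_micro_batch_size - running_bs
    if pp_size > 1:
        # Requests in other microbatches also hold req-pool slots,
        # so cap by what the request-to-token pool can still allocate.
        res = min(res, free_req_slots)
    return res

print(get_num_allocatable_reqs(running_bs=5, free_req_slots=2))  # -> 2
```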
+ # Instead, we should always allow chunked request to be added, otherwise, there will be a memory leak. + if self.get_num_allocatable_reqs(running_bs) <= 0 and not self.chunked_req: self.running_batch.batch_is_full = True return None @@ -1229,7 +1391,7 @@ class Scheduler( self.running_batch.batch_is_full = True break - if running_bs + len(adder.can_run_list) >= self.max_running_requests: + if len(adder.can_run_list) >= self.get_num_allocatable_reqs(running_bs): self.running_batch.batch_is_full = True break @@ -1241,16 +1403,14 @@ class Scheduler( res = adder.add_one_req( req, self.chunked_req, self.enable_hierarchical_cache ) + if res != AddReqResult.CONTINUE: if res == AddReqResult.NO_TOKEN: if self.enable_hierarchical_cache: # Set batch_is_full after making sure there are requests that can be served self.running_batch.batch_is_full = len( adder.can_run_list - ) > 0 or ( - self.running_batch is not None - and not self.running_batch.is_empty() - ) + ) > 0 or (not self.running_batch.is_empty()) else: self.running_batch.batch_is_full = True break @@ -1293,6 +1453,7 @@ class Scheduler( self.enable_overlap, self.spec_algorithm, self.server_args.enable_custom_logit_processor, + chunked_req=self.chunked_req, ) new_batch.prepare_for_extend() @@ -1370,9 +1531,14 @@ class Scheduler( if self.is_generation: if self.spec_algorithm.is_none(): model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - model_worker_batch - ) + if self.pp_group.is_last_rank: + logits_output, next_token_ids = ( + self.tp_worker.forward_batch_generation(model_worker_batch) + ) + else: + pp_hidden_states_proxy_tensors, _ = ( + self.tp_worker.forward_batch_generation(model_worker_batch) + ) bid = model_worker_batch.bid else: ( @@ -1386,7 +1552,9 @@ class Scheduler( ) self.spec_num_total_forward_ct += batch.batch_size() self.num_generated_tokens += num_accepted_tokens - batch.output_ids = next_token_ids + + if self.pp_group.is_last_rank: + batch.output_ids = next_token_ids # These 2 values are needed for processing the output, but the values can be # modified by overlap schedule. 
So we have to copy them here so that @@ -1401,8 +1569,13 @@ class Scheduler( extend_logprob_start_len_per_req = None ret = GenerationBatchResult( - logits_output=logits_output, - next_token_ids=next_token_ids, + logits_output=logits_output if self.pp_group.is_last_rank else None, + pp_hidden_states_proxy_tensors=( + pp_hidden_states_proxy_tensors + if not self.pp_group.is_last_rank + else None + ), + next_token_ids=next_token_ids if self.pp_group.is_last_rank else None, extend_input_len_per_req=extend_input_len_per_req, extend_logprob_start_len_per_req=extend_logprob_start_len_per_req, bid=bid, @@ -1553,6 +1726,7 @@ class Scheduler( def move_ready_grammar_requests(self): """Move requests whose grammar objects are ready from grammar_queue to waiting_queue.""" + num_ready_reqs = 0 for req in self.grammar_queue: try: @@ -1619,7 +1793,11 @@ class Scheduler( def flush_cache(self): """Flush the memory pool and cache.""" - if len(self.waiting_queue) == 0 and self.running_batch.is_empty(): + if ( + len(self.waiting_queue) == 0 + and self.running_batch.is_empty() + and (self.pp_size == 1 or all(x.is_empty() for x in self.running_mbs)) + ): self.cur_batch = None self.last_batch = None self.tree_cache.reset() @@ -1657,7 +1835,6 @@ class Scheduler( ret["avg_spec_accept_length"] = ( self.cum_spec_accept_length / self.cum_spec_accept_count ) - if RECORD_STEP_TIME: ret["step_time_dict"] = self.step_time_dict return GetInternalStateReqOutput( @@ -1668,6 +1845,7 @@ class Scheduler( server_args_dict = recv_req.server_args args_allow_update = set( [ + "max_micro_batch_size", "speculative_accept_threshold_single", "speculative_accept_threshold_acc", ] @@ -1678,6 +1856,14 @@ class Scheduler( logging.warning(f"Updating {k} is not supported.") if_success = False break + elif k == "max_micro_batch_size" and ( + v > self.max_running_requests // self.pp_size or v < 1 + ): + logging.warning( + f"Updating {k} to {v} is rejected because it is out of the valid range [1, {self.max_running_requests // self.pp_size}]." 
+ ) + if_success = False + break if if_success: if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0: avg_spec_accept_length = ( @@ -1959,6 +2145,16 @@ class Scheduler( else: del self.sessions[session_id] + def get_print_prefix(self): + prefix = "" + if self.dp_rank is not None: + prefix += f" DP{self.dp_rank}" + if self.server_args.tp_size > 1: + prefix += f" TP{self.tp_rank}" + if self.pp_size > 1: + prefix += f" PP{self.pp_rank}" + return prefix + def is_health_check_generate_req(recv_req): return getattr(recv_req, "rid", "").startswith("HEALTH_CHECK") @@ -1983,14 +2179,18 @@ def run_scheduler_process( port_args: PortArgs, gpu_id: int, tp_rank: int, + pp_rank: int, dp_rank: Optional[int], pipe_writer, ): # Generate the prefix - if dp_rank is None: - prefix = f" TP{tp_rank}" - else: - prefix = f" DP{dp_rank} TP{tp_rank}" + prefix = "" + if dp_rank is not None: + prefix += f" DP{dp_rank}" + if server_args.tp_size > 1: + prefix += f" TP{tp_rank}" + if server_args.pp_size > 1: + prefix += f" PP{pp_rank}" # Config the process kill_itself_when_parent_died() @@ -2012,7 +2212,7 @@ def run_scheduler_process( # Create a scheduler and run the event loop try: - scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank) + scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, pp_rank, dp_rank) pipe_writer.send( { "status": "ready", @@ -2023,7 +2223,9 @@ def run_scheduler_process( disaggregation_mode: DisaggregationMode = scheduler.disaggregation_mode if disaggregation_mode == DisaggregationMode.NULL: - if scheduler.enable_overlap: + if server_args.pp_size > 1: + scheduler.event_loop_pp() + elif scheduler.enable_overlap: scheduler.event_loop_overlap() else: scheduler.event_loop_normal() @@ -2032,6 +2234,7 @@ def run_scheduler_process( scheduler.event_loop_overlap_disagg_prefill() else: scheduler.event_loop_normal_disagg_prefill() + elif disaggregation_mode == DisaggregationMode.DECODE: if scheduler.enable_overlap: scheduler.event_loop_overlap_disagg_decode() diff --git a/python/sglang/srt/managers/scheduler_output_processor_mixin.py b/python/sglang/srt/managers/scheduler_output_processor_mixin.py index ce570b75a..c87c1b264 100644 --- a/python/sglang/srt/managers/scheduler_output_processor_mixin.py +++ b/python/sglang/srt/managers/scheduler_output_processor_mixin.py @@ -278,7 +278,7 @@ class SchedulerOutputProcessorMixin: self.attn_tp_rank == 0 and self.forward_ct_decode % self.server_args.decode_log_interval == 0 ): - self.log_decode_stats() + self.log_decode_stats(running_batch=batch) def add_input_logprob_return_values( self: Scheduler, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c7666ffc6..0dd2009ea 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,11 +15,12 @@ import logging import threading -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import torch from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( @@ -31,7 +32,7 @@ from sglang.srt.managers.io_struct import ( ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator -from 
sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.server_args import ServerArgs from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed @@ -47,6 +48,7 @@ class TpModelWorker: server_args: ServerArgs, gpu_id: int, tp_rank: int, + pp_rank: int, dp_rank: Optional[int], nccl_port: int, is_draft_worker: bool = False, @@ -54,7 +56,9 @@ class TpModelWorker: token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None, ): # Parse args + self.tp_size = server_args.tp_size self.tp_rank = tp_rank + self.pp_rank = pp_rank # Init model and tokenizer self.model_config = ModelConfig( @@ -73,12 +77,15 @@ class TpModelWorker: quantization=server_args.quantization, is_draft_model=is_draft_worker, ) + self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, gpu_id=gpu_id, tp_rank=tp_rank, tp_size=server_args.tp_size, + pp_rank=pp_rank, + pp_size=server_args.pp_size, nccl_port=nccl_port, server_args=server_args, is_draft_worker=is_draft_worker, @@ -105,6 +112,10 @@ class TpModelWorker: ) self.device = self.model_runner.device + # Init nccl groups + self.pp_group = get_pp_group() + self.world_group = get_world_group() + # Profile number of tokens self.max_total_num_tokens = self.model_runner.max_total_num_tokens self.max_prefill_tokens = server_args.max_prefill_tokens @@ -130,8 +141,9 @@ class TpModelWorker: # Sync random seed across TP workers self.random_seed = broadcast_pyobj( [server_args.random_seed], - self.tp_rank, - self.model_runner.tp_group.cpu_group, + self.tp_size * self.pp_rank + tp_rank, + self.world_group.cpu_group, + src=self.world_group.ranks[0], )[0] set_random_seed(self.random_seed) @@ -156,11 +168,14 @@ class TpModelWorker: def get_pad_input_ids_func(self): return getattr(self.model_runner.model, "pad_input_ids", None) - def get_tp_cpu_group(self): - return self.model_runner.tp_group.cpu_group + def get_tp_group(self): + return self.model_runner.tp_group + + def get_attention_tp_group(self): + return self.model_runner.attention_tp_group def get_attention_tp_cpu_group(self): - return self.model_runner.attention_tp_group.cpu_group + return getattr(self.model_runner.attention_tp_group, "cpu_group", None) def get_memory_pool(self): return ( @@ -172,19 +187,38 @@ class TpModelWorker: self, model_worker_batch: ModelWorkerBatch, skip_sample: bool = False, - ) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]: + ) -> Tuple[Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor]]: forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) - logits_output = self.model_runner.forward(forward_batch) - if model_worker_batch.launch_done is not None: - model_worker_batch.launch_done.set() + pp_proxy_tensors = None + if not self.pp_group.is_first_rank: + pp_proxy_tensors = PPProxyTensors( + self.pp_group.recv_tensor_dict( + all_gather_group=self.get_attention_tp_group() + ) + ) - if skip_sample: - next_token_ids = None + if self.pp_group.is_last_rank: + logits_output = self.model_runner.forward( + forward_batch, pp_proxy_tensors=pp_proxy_tensors + ) + if model_worker_batch.launch_done is not None: + model_worker_batch.launch_done.set() + + if skip_sample: + next_token_ids = None + else: + next_token_ids = self.model_runner.sample( + logits_output, model_worker_batch + ) + + return logits_output, 
next_token_ids else: - next_token_ids = self.model_runner.sample(logits_output, model_worker_batch) - - return logits_output, next_token_ids + pp_proxy_tensors = self.model_runner.forward( + forward_batch, + pp_proxy_tensors=pp_proxy_tensors, + ) + return pp_proxy_tensors.tensors, None def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py index 8aa7f3346..8bfcfe02f 100644 --- a/python/sglang/srt/managers/tp_worker_overlap_thread.py +++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py @@ -56,11 +56,14 @@ class TpModelWorkerClient: server_args: ServerArgs, gpu_id: int, tp_rank: int, + pp_rank: int, dp_rank: Optional[int], nccl_port: int, ): # Load the model - self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port) + self.worker = TpModelWorker( + server_args, gpu_id, tp_rank, pp_rank, dp_rank, nccl_port + ) self.max_running_requests = self.worker.max_running_requests self.device = self.worker.device self.gpu_id = gpu_id @@ -91,8 +94,11 @@ class TpModelWorkerClient: def get_pad_input_ids_func(self): return self.worker.get_pad_input_ids_func() - def get_tp_cpu_group(self): - return self.worker.get_tp_cpu_group() + def get_tp_group(self): + return self.worker.get_tp_group() + + def get_attention_tp_group(self): + return self.worker.get_attention_tp_group() def get_attention_tp_cpu_group(self): return self.worker.get_attention_tp_cpu_group() diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index b28ad55f9..f7eef2120 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -214,6 +214,8 @@ class MHATokenToKVPool(KVCache): layer_num: int, device: str, enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, ): self.size = size self.page_size = page_size @@ -232,6 +234,8 @@ class MHATokenToKVPool(KVCache): self.head_dim = head_dim self.layer_num = layer_num self._create_buffers() + self.start_layer = start_layer or 0 + self.end_layer = end_layer or layer_num - 1 self.layer_transfer_counter = None self.capture_mode = False @@ -281,6 +285,8 @@ class MHATokenToKVPool(KVCache): # for disagg def get_contiguous_buf_infos(self): + # layer_num x [seq_len, head_num, head_dim] + # layer_num x [page_num, page_size, head_num, head_dim] kv_data_ptrs = [ self.get_key_buffer(i).data_ptr() for i in range(self.layer_num) ] + [self.get_value_buffer(i).data_ptr() for i in range(self.layer_num)] @@ -320,24 +326,24 @@ class MHATokenToKVPool(KVCache): # transfer prepared data from host to device flat_data = flat_data.to(device=self.device, non_blocking=False) k_data, v_data = flat_data[0], flat_data[1] - self.k_buffer[layer_id][indices] = k_data - self.v_buffer[layer_id][indices] = v_data + self.k_buffer[layer_id - self.start_layer][indices] = k_data + self.v_buffer[layer_id - self.start_layer][indices] = v_data def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: - self.layer_transfer_counter.wait_until(layer_id) + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: - return self.k_buffer[layer_id].view(self.dtype) - return self.k_buffer[layer_id] + return self.k_buffer[layer_id - self.start_layer].view(self.dtype) + return self.k_buffer[layer_id - 
self.start_layer] def get_value_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: - self.layer_transfer_counter.wait_until(layer_id) + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: - return self.v_buffer[layer_id].view(self.dtype) - return self.v_buffer[layer_id] + return self.v_buffer[layer_id - self.start_layer].view(self.dtype) + return self.v_buffer[layer_id - self.start_layer] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) @@ -369,12 +375,12 @@ class MHATokenToKVPool(KVCache): current_stream = self.device_module.current_stream() self.alt_stream.wait_stream(current_stream) with self.device_module.stream(self.alt_stream): - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v current_stream.wait_stream(self.alt_stream) else: - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v @torch.compile @@ -484,6 +490,8 @@ class MLATokenToKVPool(KVCache): layer_num: int, device: str, enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, ): self.size = size self.page_size = page_size @@ -497,6 +505,8 @@ class MLATokenToKVPool(KVCache): self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.layer_num = layer_num + self.start_layer = start_layer or 0 + self.end_layer = end_layer or layer_num - 1 memory_saver_adapter = TorchMemorySaverAdapter.create( enable=enable_memory_saver @@ -540,19 +550,21 @@ class MLATokenToKVPool(KVCache): def get_key_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: - self.layer_transfer_counter.wait_until(layer_id) + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: - return self.kv_buffer[layer_id].view(self.dtype) - return self.kv_buffer[layer_id] + return self.kv_buffer[layer_id - self.start_layer].view(self.dtype) + return self.kv_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): if self.layer_transfer_counter is not None: - self.layer_transfer_counter.wait_until(layer_id) + self.layer_transfer_counter.wait_until(layer_id - self.start_layer) if self.store_dtype != self.dtype: - return self.kv_buffer[layer_id][..., : self.kv_lora_rank].view(self.dtype) - return self.kv_buffer[layer_id][..., : self.kv_lora_rank] + return self.kv_buffer[layer_id - self.start_layer][ + ..., : self.kv_lora_rank + ].view(self.dtype) + return self.kv_buffer[layer_id - self.start_layer][..., : self.kv_lora_rank] def get_kv_buffer(self, layer_id: int): return self.get_key_buffer(layer_id), self.get_value_buffer(layer_id) @@ -568,9 +580,11 @@ class MLATokenToKVPool(KVCache): if cache_k.dtype != self.dtype: cache_k = cache_k.to(self.dtype) if self.store_dtype != self.dtype: - self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype) + self.kv_buffer[layer_id - self.start_layer][loc] = cache_k.view( + self.store_dtype + ) else: - self.kv_buffer[layer_id][loc] = cache_k + self.kv_buffer[layer_id - self.start_layer][loc] = cache_k def set_mla_kv_buffer( self, @@ -605,7 +619,7 @@ class MLATokenToKVPool(KVCache): def transfer_per_layer(self, indices, flat_data, layer_id): # transfer prepared data 
from host to device flat_data = flat_data.to(device=self.device, non_blocking=False) - self.kv_buffer[layer_id][indices] = flat_data + self.kv_buffer[layer_id - self.start_layer][indices] = flat_data class DoubleSparseTokenToKVPool(KVCache): @@ -620,6 +634,8 @@ class DoubleSparseTokenToKVPool(KVCache): device: str, heavy_channel_num: int, enable_memory_saver: bool, + start_layer: Optional[int] = None, + end_layer: Optional[int] = None, ): self.size = size self.page_size = page_size @@ -657,17 +673,23 @@ class DoubleSparseTokenToKVPool(KVCache): for _ in range(layer_num) ] + self.start_layer = start_layer or 0 + self.end_layer = end_layer or layer_num - 1 + def get_key_buffer(self, layer_id: int): - return self.k_buffer[layer_id] + return self.k_buffer[layer_id - self.start_layer] def get_value_buffer(self, layer_id: int): - return self.v_buffer[layer_id] + return self.v_buffer[layer_id - self.start_layer] def get_label_buffer(self, layer_id: int): - return self.label_buffer[layer_id] + return self.label_buffer[layer_id - self.start_layer] def get_kv_buffer(self, layer_id: int): - return self.k_buffer[layer_id], self.v_buffer[layer_id] + return ( + self.k_buffer[layer_id - self.start_layer], + self.v_buffer[layer_id - self.start_layer], + ) def set_kv_buffer( self, @@ -679,9 +701,9 @@ class DoubleSparseTokenToKVPool(KVCache): ): # NOTE(Andy): ignore the dtype check layer_id = layer.layer_id - self.k_buffer[layer_id][loc] = cache_k - self.v_buffer[layer_id][loc] = cache_v - self.label_buffer[layer_id][loc] = cache_label + self.k_buffer[layer_id - self.start_layer][loc] = cache_k + self.v_buffer[layer_id - self.start_layer][loc] = cache_v + self.label_buffer[layer_id - self.start_layer][loc] = cache_label def get_flat_data(self, indices): pass @@ -930,7 +952,7 @@ class MHATokenToKVPoolHost(HostKVCache): return self.kv_buffer[:, :, indices] def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[:, layer_id, indices] + return self.kv_buffer[:, layer_id - self.start_layer, indices] def assign_flat_data(self, indices, flat_data): self.kv_buffer[:, :, indices] = flat_data @@ -955,12 +977,20 @@ class MHATokenToKVPoolHost(HostKVCache): for i in range(len(device_indices_cpu)): h_index = host_indices[i * self.page_size] d_index = device_indices_cpu[i] - device_pool.k_buffer[layer_id][d_index : d_index + self.page_size].copy_( - self.kv_buffer[0, layer_id, h_index : h_index + self.page_size], + device_pool.k_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + 0, layer_id - self.start_layer, h_index : h_index + self.page_size + ], non_blocking=True, ) - device_pool.v_buffer[layer_id][d_index : d_index + self.page_size].copy_( - self.kv_buffer[1, layer_id, h_index : h_index + self.page_size], + device_pool.v_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + 1, layer_id - self.start_layer, h_index : h_index + self.page_size + ], non_blocking=True, ) @@ -1015,7 +1045,7 @@ class MLATokenToKVPoolHost(HostKVCache): return self.kv_buffer[:, indices] def get_flat_data_by_layer(self, indices, layer_id): - return self.kv_buffer[layer_id, indices] + return self.kv_buffer[layer_id - self.start_layer, indices] def assign_flat_data(self, indices, flat_data): self.kv_buffer[:, indices] = flat_data @@ -1036,7 +1066,11 @@ class MLATokenToKVPoolHost(HostKVCache): for i in range(len(device_indices_cpu)): h_index = host_indices[i * self.page_size] d_index = device_indices_cpu[i] - 
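All of these `layer_id - self.start_layer` edits implement one convention: a stage allocates buffers only for its own `end_layer - start_layer` layers but is still addressed with global layer ids. A minimal sketch of that convention (`ToyStageKVPool` and all sizes are illustrative, not the patch's classes):

```python
import torch

# Minimal sketch: a per-stage KV buffer indexed by global layer id,
# mirroring the `layer_id - self.start_layer` pattern in the pools above.
class ToyStageKVPool:
    def __init__(self, size, head_num, head_dim, start_layer, end_layer):
        self.start_layer, self.end_layer = start_layer, end_layer
        num_local_layers = end_layer - start_layer  # only this stage's layers
        self.k_buffer = [
            torch.zeros(size, head_num, head_dim) for _ in range(num_local_layers)
        ]

    def get_key_buffer(self, layer_id: int):
        assert self.start_layer <= layer_id < self.end_layer
        return self.k_buffer[layer_id - self.start_layer]

# A stage owning global layers 16..31 stores them in local slots 0..15.
pool = ToyStageKVPool(size=8, head_num=2, head_dim=4, start_layer=16, end_layer=32)
assert pool.get_key_buffer(16) is pool.k_buffer[0]
```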
device_pool.kv_buffer[layer_id][d_index : d_index + self.page_size].copy_( - self.kv_buffer[layer_id, h_index : h_index + self.page_size], + device_pool.kv_buffer[layer_id - self.start_layer][ + d_index : d_index + self.page_size + ].copy_( + self.kv_buffer[ + layer_id - self.start_layer, h_index : h_index + self.page_size + ], non_blocking=True, ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index b36f15a86..b354215c1 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -16,6 +16,7 @@ from __future__ import annotations import bisect +import inspect import os from contextlib import contextmanager from typing import TYPE_CHECKING, Callable @@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, ForwardBatch, ForwardMode, + PPProxyTensors, ) from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.utils import ( get_available_gpu_memory, get_device_memory_capacity, is_hip, + rank0_log, ) if TYPE_CHECKING: @@ -188,10 +191,11 @@ class CudaGraphRunner: self.speculative_algorithm = model_runner.server_args.speculative_algorithm self.tp_size = model_runner.server_args.tp_size self.dp_size = model_runner.server_args.dp_size + self.pp_size = model_runner.server_args.pp_size # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) - + rank0_log(f"Capture cuda graph bs {self.capture_bs}") self.capture_forward_mode = ForwardMode.DECODE self.capture_hidden_mode = CaptureHiddenMode.NULL self.num_tokens_per_bs = 1 @@ -234,6 +238,19 @@ class CudaGraphRunner: self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64) + # pipeline parallelism + if self.pp_size > 1: + self.pp_proxy_tensors = { + "hidden_states": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + "residual": torch.zeros( + (self.max_bs, self.model_runner.model_config.hidden_size), + dtype=torch.bfloat16, + ), + } + # Speculative_inference if ( model_runner.spec_algorithm.is_eagle3() @@ -384,6 +401,12 @@ class CudaGraphRunner: encoder_lens = None mrope_positions = self.mrope_positions[:, :bs] + # pipeline parallelism + if self.pp_size > 1: + pp_proxy_tensors = PPProxyTensors( + {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()} + ) + if self.enable_dp_attention or self.enable_sp_layernorm: self.global_num_tokens_gpu.copy_( torch.tensor( @@ -456,8 +479,20 @@ class CudaGraphRunner: # Clean intermediate result cache for DP attention forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None - logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits, logits_output.hidden_states + kwargs = {} + if ( + self.pp_size > 1 + and "pp_proxy_tensors" in inspect.signature(forward).parameters + ): + kwargs["pp_proxy_tensors"] = pp_proxy_tensors + + logits_output_or_pp_proxy_tensors = forward( + input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + return logits_output_or_pp_proxy_tensors for _ in range(2): torch.cuda.synchronize() @@ -490,7 +525,11 @@ class CudaGraphRunner: self.capture_hidden_mode = hidden_mode_from_spec_info self.capture() - def replay_prepare(self, forward_batch: ForwardBatch): + def replay_prepare( + self, + forward_batch: ForwardBatch, + 
pp_proxy_tensors: Optional[PPProxyTensors] = None, + ): self.recapture_if_needed(forward_batch) raw_bs = forward_batch.batch_size @@ -519,6 +558,11 @@ class CudaGraphRunner: self.seq_lens_cpu.fill_(1) self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu) + if pp_proxy_tensors: + for key in self.pp_proxy_tensors.keys(): + dim = pp_proxy_tensors[key].shape[0] + self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key]) + if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: @@ -547,10 +591,13 @@ class CudaGraphRunner: self.bs = bs def replay( - self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False - ) -> LogitsProcessorOutput: + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: if not skip_attn_backend_init: - self.replay_prepare(forward_batch) + self.replay_prepare(forward_batch, pp_proxy_tensors) else: # In speculative decoding, these two fields are still needed. self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids) @@ -558,17 +605,19 @@ class CudaGraphRunner: # Replay self.graphs[self.bs].replay() - next_token_logits, hidden_states = self.output_buffers[self.bs] - - logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits[: self.raw_num_token], - hidden_states=( - hidden_states[: self.raw_num_token] - if hidden_states is not None - else None - ), - ) - return logits_output + output = self.output_buffers[self.bs] + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[: self.raw_num_token], + hidden_states=( + output.hidden_states[: self.raw_num_token] + if output.hidden_states is not None + else None + ), + ) + else: + assert isinstance(output, PPProxyTensors) + return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()}) def get_spec_info(self, num_tokens: int): spec_info = None diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index e493dec7a..8f84c98e4 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -31,7 +31,7 @@ from __future__ import annotations from dataclasses import dataclass from enum import IntEnum, auto -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union import torch import triton @@ -585,6 +585,36 @@ class ForwardBatch: self.prepare_chunked_kv_indices(device) +class PPProxyTensors: + # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103 + tensors: Dict[str, torch.Tensor] + + def __init__(self, tensors): + # manually define this function, so that + # Dynamo knows `IntermediateTensors()` comes from this file. + # Otherwise, dataclass will generate this function by evaluating + # a string, and we will lose the information about the source file. 
+ self.tensors = tensors + + def __getitem__(self, key: Union[str, slice]): + if isinstance(key, str): + return self.tensors[key] + elif isinstance(key, slice): + return self.__class__({k: v[key] for k, v in self.tensors.items()}) + + def __setitem__(self, key: str, value: torch.Tensor): + self.tensors[key] = value + + def __len__(self): + return len(self.tensors) + + def __eq__(self, other: object): + return isinstance(other, self.__class__) and self + + def __repr__(self) -> str: + return f"PPProxyTensors(tensors={self.tensors})" + + def compute_position_triton( extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum ): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8b0ad93a5..5537daf18 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -13,8 +13,10 @@ # ============================================================================== """ModelRunner runs the forward passes of the models.""" +import collections import datetime import gc +import inspect import json import logging import os @@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import ( ) from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader import get_model from sglang.srt.model_loader.loader import ( DefaultModelLoader, @@ -111,6 +113,8 @@ class ModelRunner: gpu_id: int, tp_rank: int, tp_size: int, + pp_rank: int, + pp_size: int, nccl_port: int, server_args: ServerArgs, is_draft_worker: bool = False, @@ -124,6 +128,8 @@ class ModelRunner: self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = tp_size + self.pp_rank = pp_rank + self.pp_size = pp_size self.dist_port = nccl_port self.server_args = server_args self.is_draft_worker = is_draft_worker @@ -149,24 +155,24 @@ class ModelRunner: global_server_args_dict.update( { "attention_backend": server_args.attention_backend, - "sampling_backend": server_args.sampling_backend, - "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, - "torchao_config": server_args.torchao_config, + "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, + "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, + "deepep_mode": server_args.deepep_mode, + "device": server_args.device, + "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache, + "disable_radix_cache": server_args.disable_radix_cache, "enable_nan_detection": server_args.enable_nan_detection, "enable_dp_attention": server_args.enable_dp_attention, "enable_ep_moe": server_args.enable_ep_moe, "enable_deepep_moe": server_args.enable_deepep_moe, - "deepep_mode": server_args.deepep_mode, - "device": server_args.device, - "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, - "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, - "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "moe_dense_tp_size": server_args.moe_dense_tp_size, - "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, - "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, 
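`PPProxyTensors` is a thin dict-of-tensors wrapper: string keys index individual tensors, a slice key produces a sliced copy (used by the CUDA-graph replay path to shrink a captured batch), and `__len__` counts entries. One caveat: as written, `__eq__` returns `self` (a truthy object) whenever `other` is a `PPProxyTensors`, rather than comparing `tensors`, so it should not be relied on for equality. A short usage sketch with toy tensors, assuming the class is imported as in this patch:

```python
import torch
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

hidden = torch.randn(4, 8)
residual = torch.randn(4, 8)
proxy = PPProxyTensors({"hidden_states": hidden, "residual": residual})

proxy["residual"] = residual * 2           # __setitem__ with a string key
first_two = proxy[:2]                      # slice key -> new PPProxyTensors
assert isinstance(first_two, PPProxyTensors)
assert first_two["hidden_states"].shape == (2, 8)
print(len(proxy), sorted(proxy.tensors))   # 2 ['hidden_states', 'residual']
```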
"n_share_experts_fusion": server_args.n_share_experts_fusion, - "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache, + "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, + "torchao_config": server_args.torchao_config, + "sampling_backend": server_args.sampling_backend, + "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, + "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, "use_mla_backend": self.use_mla_backend, } ) @@ -184,6 +190,11 @@ class ModelRunner: # If it is a draft model, tp_group can be different self.initialize(min_per_gpu_memory) + # temporary cached values + self.support_pp = ( + "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters + ) + def initialize(self, min_per_gpu_memory: float): server_args = self.server_args self.memory_saver_adapter = TorchMemorySaverAdapter.create( @@ -194,6 +205,12 @@ class ModelRunner: self.sampler = Sampler() self.load_model() + self.start_layer = getattr(self.model, "start_layer", 0) + self.end_layer = getattr( + self.model, "end_layer", self.model_config.num_hidden_layers + ) + self.num_effective_layers = self.end_layer - self.start_layer + # Apply torchao quantization torchao_applied = getattr(self.model, "torchao_applied", False) # In layered loading, torchao may have been applied @@ -360,18 +377,22 @@ class ModelRunner: # Only initialize the distributed environment on the target model worker. init_distributed_environment( backend=backend, - world_size=self.tp_size, - rank=self.tp_rank, + world_size=self.tp_size * self.pp_size, + rank=self.tp_size * self.pp_rank + self.tp_rank, local_rank=self.gpu_id, distributed_init_method=dist_init_method, timeout=self.server_args.dist_timeout, ) - initialize_model_parallel(tensor_model_parallel_size=self.tp_size) + initialize_model_parallel( + tensor_model_parallel_size=self.tp_size, + pipeline_model_parallel_size=self.pp_size, + ) initialize_dp_attention( enable_dp_attention=self.server_args.enable_dp_attention, tp_rank=self.tp_rank, tp_size=self.tp_size, dp_size=self.server_args.dp_size, + pp_size=self.server_args.pp_size, ) min_per_gpu_memory = get_available_gpu_memory( @@ -698,6 +719,8 @@ class ModelRunner: if not self.is_draft_worker else self.model_config.hf_config.num_nextn_predict_layers ) + # FIXME: pipeline parallelism is not compatible with mla backend + assert self.pp_size == 1 cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * num_layers @@ -707,7 +730,7 @@ class ModelRunner: cell_size = ( self.model_config.get_num_kv_heads(get_attention_tp_size()) * self.model_config.head_dim - * self.model_config.num_hidden_layers + * self.num_effective_layers * 2 * torch._utils._element_size(self.kv_cache_dtype) ) @@ -819,9 +842,11 @@ class ModelRunner: self.model_config.num_hidden_layers if not self.is_draft_worker else self.model_config.hf_config.num_nextn_predict_layers - ), + ), # PP is not compatible with mla backend device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, ) elif self.server_args.enable_double_sparsity: self.token_to_kv_pool = DoubleSparseTokenToKVPool( @@ -830,10 +855,12 @@ class ModelRunner: dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, - layer_num=self.model_config.num_hidden_layers, + layer_num=self.num_effective_layers, device=self.device, 
heavy_channel_num=self.server_args.ds_heavy_channel_num, enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, ) else: self.token_to_kv_pool = MHATokenToKVPool( @@ -842,9 +869,11 @@ class ModelRunner: dtype=self.kv_cache_dtype, head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), head_dim=self.model_config.head_dim, - layer_num=self.model_config.num_hidden_layers, + layer_num=self.num_effective_layers, device=self.device, enable_memory_saver=self.server_args.enable_memory_saver, + start_layer=self.start_layer, + end_layer=self.end_layer, ) if self.token_to_kv_pool_allocator is None: @@ -957,7 +986,7 @@ class ModelRunner: with open(self.server_args.ds_channel_config_path, "r") as f: channel_config = json.load(f) - for i in range(self.model_config.num_hidden_layers): + for i in range(self.start_layer, self.end_layer): key = "model.layers." + str(i) + ".self_attn" + selected_channel self.sorted_channels.append( torch.tensor(channel_config[key])[ @@ -997,64 +1026,82 @@ class ModelRunner: device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,)) tensor_parallel(self.model, device_mesh) - def forward_decode(self, forward_batch: ForwardBatch): + def forward_decode( + self, forward_batch: ForwardBatch, pp_proxy_tensors=None + ) -> LogitsProcessorOutput: self.attn_backend.init_forward_metadata(forward_batch) + # FIXME: add pp_proxy_tensors arg to all models + kwargs = {} + if self.support_pp: + kwargs["pp_proxy_tensors"] = pp_proxy_tensors return self.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs ) def forward_extend( - self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False - ): + self, + forward_batch: ForwardBatch, + skip_attn_backend_init: bool = False, + pp_proxy_tensors=None, + ) -> LogitsProcessorOutput: if not skip_attn_backend_init: self.attn_backend.init_forward_metadata(forward_batch) - if self.is_generation: - if forward_batch.input_embeds is None: - return self.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch - ) - else: - return self.model.forward( - forward_batch.input_ids, - forward_batch.positions, - forward_batch, - input_embeds=forward_batch.input_embeds.bfloat16(), - ) - else: - # Only embedding models have get_embedding parameter - return self.model.forward( - forward_batch.input_ids, - forward_batch.positions, - forward_batch, - get_embedding=True, - ) - - def forward_idle(self, forward_batch: ForwardBatch): + kwargs = {} + if self.support_pp: + kwargs["pp_proxy_tensors"] = pp_proxy_tensors + if forward_batch.input_embeds is not None: + kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16() + if not self.is_generation: + kwargs["get_embedding"] = True return self.model.forward( - forward_batch.input_ids, forward_batch.positions, forward_batch + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, + ) + + def forward_idle( + self, forward_batch: ForwardBatch, pp_proxy_tensors=None + ) -> LogitsProcessorOutput: + kwargs = {} + if self.support_pp: + kwargs["pp_proxy_tensors"] = pp_proxy_tensors + return self.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + **kwargs, ) def forward( - self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False - ) -> LogitsProcessorOutput: - if ( + self, + forward_batch: ForwardBatch, + 
skip_attn_backend_init: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[LogitsProcessorOutput, PPProxyTensors]: + can_run_cuda_graph = bool( forward_batch.forward_mode.is_cuda_graph() and self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch) - ): + ) + if can_run_cuda_graph: return self.cuda_graph_runner.replay( - forward_batch, skip_attn_backend_init=skip_attn_backend_init + forward_batch, + skip_attn_backend_init=skip_attn_backend_init, + pp_proxy_tensors=pp_proxy_tensors, ) if forward_batch.forward_mode.is_decode(): - return self.forward_decode(forward_batch) + return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors) elif forward_batch.forward_mode.is_extend(): return self.forward_extend( - forward_batch, skip_attn_backend_init=skip_attn_backend_init + forward_batch, + skip_attn_backend_init=skip_attn_backend_init, + pp_proxy_tensors=pp_proxy_tensors, ) elif forward_batch.forward_mode.is_idle(): - return self.forward_idle(forward_batch) + return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors) else: raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}") diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 008e54204..c6743e344 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -17,13 +17,14 @@ """Inference-only LLaMA model compatible with HuggingFace weights.""" import logging -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn from transformers import LlamaConfig from sglang.srt.distributed import ( + get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) @@ -39,11 +40,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope +from sglang.srt.layers.utils import PPMissingLayer, get_layer_id from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, @@ -275,21 +277,31 @@ class LlamaModel(nn.Module): self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=add_prefix("embed_tokens", prefix), - ) - self.layers = make_layers( + self.pp_group = get_pp_group() + if self.pp_group.is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=add_prefix("embed_tokens", prefix), + ) + else: + self.embed_tokens = PPMissingLayer() + + self.layers, self.start_layer, self.end_layer = make_layers( config.num_hidden_layers, lambda idx, prefix: LlamaDecoderLayer( - config=config, layer_id=idx, quant_config=quant_config, prefix=prefix + config=config, quant_config=quant_config, layer_id=idx, prefix=prefix ), + pp_rank=self.pp_group.rank_in_group, + pp_size=self.pp_group.world_size, prefix="model.layers", ) - self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + if self.pp_group.is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer(return_tuple=True) self.layers_to_capture = [] def forward( @@ -298,14 +310,23 @@ class LlamaModel(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, - ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: - if input_embeds is None: - hidden_states = self.embed_tokens(input_ids) + pp_proxy_tensors: Optional[PPProxyTensors] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]: + if self.pp_group.is_first_rank: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + residual = None else: - hidden_states = input_embeds - residual = None + assert pp_proxy_tensors is not None + # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms + hidden_states = pp_proxy_tensors["hidden_states"] + residual = pp_proxy_tensors["residual"] + deferred_norm = None + aux_hidden_states = [] - for i in range(len(self.layers)): + for i in range(self.start_layer, self.end_layer): if i in self.layers_to_capture: aux_hidden_states.append(hidden_states + residual) layer = self.layers[i] @@ -315,7 +336,16 @@ class LlamaModel(nn.Module): forward_batch, residual, ) - hidden_states, _ = self.norm(hidden_states, residual) + + if not self.pp_group.is_last_rank: + return PPProxyTensors( + { + "hidden_states": hidden_states, + "residual": residual, + } + ) + else: + hidden_states, _ = self.norm(hidden_states, residual) if len(aux_hidden_states) == 0: return hidden_states @@ -376,6 +406,7 @@ class LlamaForCausalLM(nn.Module): prefix: str = "", ) -> None: super().__init__() + self.pp_group = get_pp_group() self.config = config self.quant_config = quant_config self.model = self._init_model(config, quant_config, add_prefix("model", prefix)) @@ -419,23 +450,41 @@ class LlamaForCausalLM(nn.Module): forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, get_embedding: bool = False, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> LogitsProcessorOutput: + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds, + pp_proxy_tensors=pp_proxy_tensors, + ) + aux_hidden_states = None if self.capture_aux_hidden_states: - hidden_states, aux_hidden_states = self.model( - input_ids, positions, forward_batch, input_embeds - ) - else: - hidden_states = self.model( - input_ids, positions, forward_batch, input_embeds - ) + hidden_states, aux_hidden_states = hidden_states - if not get_embedding: - return self.logits_processor( - input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states - ) + if self.pp_group.is_last_rank: + if not get_embedding: + return self.logits_processor( + input_ids, + hidden_states, + self.lm_head, + forward_batch, + aux_hidden_states, + ) + else: + return self.pooler(hidden_states, forward_batch) else: - return self.pooler(hidden_states, forward_batch) + return hidden_states + + @property + def start_layer(self): + return self.model.start_layer + + @property + def end_layer(self): + return self.model.end_layer def get_input_embeddings(self) -> nn.Embedding: return self.model.embed_tokens @@ -491,6 +540,16 @@ class LlamaForCausalLM(nn.Module): params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + layer_id = get_layer_id(name) + if ( + layer_id is not None + and hasattr(self.model, "start_layer") + 
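+ # Weights for layers outside [start_layer, end_layer) belong to other
+ # pipeline stages, so this rank skips them during loading.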
and ( + layer_id < self.model.start_layer + or layer_id >= self.model.end_layer + ) + ): + continue if "rotary_emb.inv_freq" in name or "projector" in name: continue if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: @@ -637,6 +696,9 @@ class LlamaForCausalLM(nn.Module): self.model.load_kv_cache_scales(quantization_param_path) def set_eagle3_layers_to_capture(self): + if not self.pp_group.is_last_rank: + return + self.capture_aux_hidden_states = True num_layers = self.config.num_hidden_layers self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3] diff --git a/python/sglang/srt/models/llama4.py b/python/sglang/srt/models/llama4.py index 95edfa40e..73c707508 100644 --- a/python/sglang/srt/models/llama4.py +++ b/python/sglang/srt/models/llama4.py @@ -46,7 +46,7 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.rotary_embedding import get_rope from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers @@ -431,6 +431,7 @@ class Llama4Model(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) diff --git a/python/sglang/srt/models/llama_eagle.py b/python/sglang/srt/models/llama_eagle.py index b04d334bd..881731fdf 100644 --- a/python/sglang/srt/models/llama_eagle.py +++ b/python/sglang/srt/models/llama_eagle.py @@ -25,13 +25,14 @@ import torch from torch import nn from transformers import LlamaConfig +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.models.llama import LlamaDecoderLayer, LlamaForCausalLM @@ -86,6 +87,7 @@ class LlamaModel(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> torch.Tensor: if input_embeds is None: hidden_states = self.embed_tokens(input_ids) @@ -118,6 +120,7 @@ class LlamaForCausalLMEagle(LlamaForCausalLM): nn.Module.__init__(self) self.config = config self.quant_config = quant_config + self.pp_group = get_pp_group() self.model = LlamaModel( config, quant_config=quant_config, prefix=add_prefix("model", prefix) ) diff --git a/python/sglang/srt/models/llama_eagle3.py b/python/sglang/srt/models/llama_eagle3.py index 137a6da56..bbc58ae60 100644 --- a/python/sglang/srt/models/llama_eagle3.py +++ b/python/sglang/srt/models/llama_eagle3.py @@ -25,6 +25,7 @@ import torch from torch import nn from transformers import LlamaConfig +from sglang.srt.distributed import get_pp_group from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import QKVParallelLinear, 
RowParallelLinear from sglang.srt.layers.logits_processor import LogitsProcessor @@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM @@ -118,6 +119,7 @@ class LlamaModel(nn.Module): positions: torch.Tensor, forward_batch: ForwardBatch, input_embeds: torch.Tensor = None, + pp_proxy_tensors: Optional[PPProxyTensors] = None, ) -> torch.Tensor: if input_embeds is None: embeds = self.embed_tokens(input_ids) @@ -155,6 +157,7 @@ class LlamaForCausalLMEagle3(LlamaForCausalLM): nn.Module.__init__(self) self.config = config self.quant_config = quant_config + self.pp_group = get_pp_group() if self.config.num_hidden_layers != 1: raise ValueError("EAGLE3 currently only supports 1 layer") diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index a23ee4ad5..1d7c2aa1a 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -78,6 +78,8 @@ class ServerArgs: # Other runtime options tp_size: int = 1 + pp_size: int = 1 + max_micro_batch_size: Optional[int] = None stream_interval: int = 1 stream_output: bool = False random_seed: Optional[int] = None @@ -222,14 +224,18 @@ class ServerArgs: # Set mem fraction static, which depends on the tensor parallelism size if self.mem_fraction_static is None: - if self.tp_size >= 16: - self.mem_fraction_static = 0.79 - elif self.tp_size >= 8: - self.mem_fraction_static = 0.81 - elif self.tp_size >= 4: - self.mem_fraction_static = 0.85 - elif self.tp_size >= 2: - self.mem_fraction_static = 0.87 + parallel_size = self.tp_size * self.pp_size + if gpu_mem <= 81920: + if parallel_size >= 16: + self.mem_fraction_static = 0.79 + elif parallel_size >= 8: + self.mem_fraction_static = 0.81 + elif parallel_size >= 4: + self.mem_fraction_static = 0.85 + elif parallel_size >= 2: + self.mem_fraction_static = 0.87 + else: + self.mem_fraction_static = 0.88 else: self.mem_fraction_static = 0.88 if gpu_mem > 96 * 1024: @@ -244,6 +250,8 @@ class ServerArgs: if self.chunked_prefill_size is None: if gpu_mem is not None and gpu_mem < 25_000: self.chunked_prefill_size = 2048 + elif self.disaggregation_mode != "null": + self.chunked_prefill_size = 16384 else: self.chunked_prefill_size = 8192 assert self.chunked_prefill_size % self.page_size == 0 @@ -643,6 +651,19 @@ class ServerArgs: default=ServerArgs.tp_size, help="The tensor parallelism size.", ) + parser.add_argument( + "--pipeline-parallel-size", + "--pp-size", + type=int, + default=ServerArgs.pp_size, + help="The pipeline parallelism size.", + ) + parser.add_argument( + "--max-micro-batch-size", + type=int, + default=ServerArgs.max_micro_batch_size, + help="The maximum micro batch size in pipeline parallelism.", + ) parser.add_argument( "--stream-interval", type=int, @@ -1232,6 +1253,7 @@ class ServerArgs: @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size + args.pp_size = args.pipeline_parallel_size args.dp_size = args.data_parallel_size args.ep_size = args.expert_parallel_size attrs = [attr.name for attr in dataclasses.fields(cls)] @@ -1245,8 +1267,19 @@ class ServerArgs: def check_server_args(self): assert ( - self.tp_size % self.nnodes == 0 - ), "tp_size must be divisible by number of nodes" + self.tp_size * self.pp_size 
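+ # The full world size is tp_size * pp_size; it must split evenly
+ # across the participating nodes.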
+ ) % self.nnodes == 0, "The product of tp_size and pp_size must be divisible by nnodes"
+
+ # FIXME: relax the remaining pipeline-parallelism constraints below
+ if self.pp_size > 1:
+ logger.warning("Turning off the overlap schedule for pipeline parallelism.")
+ self.disable_overlap_schedule = True
+ assert (
+ self.disable_overlap_schedule
+ and self.speculative_algorithm is None
+ and not self.enable_mixed_chunk
+ ), "Pipeline parallelism is not compatible with the overlap schedule, speculative decoding, or mixed chunked prefill."
+
 assert not (
 self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
 ), "multi-node data parallel is not supported unless dp attention!"
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 64de7dbb4..d1fd04e93 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -106,11 +106,12 @@ class EAGLEWorker(TpModelWorker):
 # Init draft worker
 with empty_context():
 super().__init__(
+ server_args=server_args,
 gpu_id=gpu_id,
 tp_rank=tp_rank,
- server_args=server_args,
- nccl_port=nccl_port,
+ pp_rank=0, # FIXME
 dp_rank=dp_rank,
+ nccl_port=nccl_port,
 is_draft_worker=True,
 req_to_token_pool=self.req_to_token_pool,
 token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index ef18002e0..3c68e6057 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -12,6 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 """Common utilities."""
+
 import base64
 import builtins
 import ctypes
@@ -414,16 +415,40 @@ class LayerFn(Protocol):
 def make_layers(
 num_hidden_layers: int,
 layer_fn: LayerFn,
+ pp_rank: Optional[int] = None,
+ pp_size: Optional[int] = None,
 prefix: str = "",
+ return_tuple: bool = False,
-) -> Tuple[int, int, torch.nn.ModuleList]:
+) -> Union[torch.nn.ModuleList, Tuple[torch.nn.ModuleList, int, int]]:
 """Make a list of layers with the given layer function"""
+ # Imported here to avoid circular imports.
+ from sglang.srt.distributed import get_pp_indices
+ from sglang.srt.layers.utils import PPMissingLayer
+
+ assert not pp_size or num_hidden_layers >= pp_size
+ start_layer, end_layer = (
+ get_pp_indices(
+ num_hidden_layers,
+ pp_rank,
+ pp_size,
+ )
+ if pp_rank is not None and pp_size is not None
+ else (0, num_hidden_layers)
+ )
 modules = torch.nn.ModuleList(
- [
+ [PPMissingLayer(return_tuple=return_tuple) for _ in range(start_layer)]
+ + [
 maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
- for idx in range(num_hidden_layers)
+ for idx in range(start_layer, end_layer)
+ ]
+ + [
+ PPMissingLayer(return_tuple=return_tuple)
+ for _ in range(end_layer, num_hidden_layers)
 ]
 )
- return modules
+ if pp_rank is None or pp_size is None:
+ return modules
+ return modules, start_layer, end_layer


 def set_random_seed(seed: int) -> None:
@@ -877,7 +902,7 @@ def broadcast_pyobj(
 "cuda" if torch.cuda.is_available() and not force_cpu_device else "cpu"
 )

- if rank == 0:
+ if rank == src:
 if len(data) == 0:
 tensor_size = torch.tensor([0], dtype=torch.long, device=device)
 dist.broadcast(tensor_size, src=src, group=dist_group)
@@ -909,6 +934,50 @@ def broadcast_pyobj(
 return data


+def point_to_point_pyobj(
+ data: List[Any],
+ rank: int,
+ group: Optional[torch.distributed.ProcessGroup] = None,
+ src: int = 0,
+ dst: int = 1,
+):
+ """Send data from src to dst in group."""
+
+ if rank == src:
+ if len(data) == 0:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.send(tensor_size, dst=dst, group=group)
+ else:
+ serialized_data = pickle.dumps(data)
+ size = len(serialized_data)
+ tensor_data = torch.ByteTensor(
+ np.frombuffer(serialized_data, dtype=np.uint8)
+ )
+ tensor_size = torch.tensor([size], dtype=torch.long)
+
+ dist.send(tensor_size, dst=dst, group=group)
+ dist.send(tensor_data, dst=dst, group=group)
+ return data
+
+ elif rank == dst:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.recv(tensor_size, src=src, group=group)
+ size = tensor_size.item()
+
+ if size == 0:
+ return []
+
+ tensor_data = torch.empty(size, dtype=torch.uint8)
+ dist.recv(tensor_data, src=src, group=group)
+
+ serialized_data = bytes(tensor_data.cpu().numpy())
+ data = pickle.loads(serialized_data)
+ return data
+
+ # Other ranks in pp_group do nothing
+ return []
+
+
 step_counter = 0
@@ -1732,6 +1801,13 @@ def configure_ipv6(dist_init_addr):
 return port, host


+def rank0_log(msg: str):
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+
+ if get_tensor_model_parallel_rank() == 0:
+ logger.info(msg)
+
+
 def rank0_print(msg: str):
 from sglang.srt.distributed import get_tensor_model_parallel_rank

diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 9db6cbd9f..7bf35582b 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -770,6 +770,34 @@ def run_bench_offline_throughput(model, other_args):
 return output_throughput


+def run_bench_one_batch_server(
+ model,
+ base_url,
+ server_args,
+ bench_args,
+ other_server_args,
+ simulate_spec_acc_lens=None,
+):
+ from sglang.bench_one_batch_server import run_benchmark
+
+ if simulate_spec_acc_lens is not None:
+ env = {**os.environ, "SIMULATE_ACC_LEN": str(simulate_spec_acc_lens)}
+ else:
+ env = None
+
+ process = popen_launch_server(
+ model,
+ base_url,
+ timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ other_args=other_server_args,
+ env=env,
+ )
+ try:
+ run_benchmark(server_args=server_args, bench_args=bench_args)
+ finally:
+ kill_process_tree(process.pid)
+
+
 def lcs(X, Y):
 m = len(X)
 n = len(Y)
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 728739f75..9f593ca9f 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -96,6 +96,8 @@ suites = {
 "per-commit-8-gpu": [
 TestFile("test_local_attn.py", 250),
 TestFile("test_full_deepseek_v3.py", 250),
+ TestFile("test_fa3.py", 30),
+ TestFile("test_pp_single_node.py", 150),
 ],
 "nightly": [
 TestFile("test_nightly_gsm8k_eval.py"),
diff --git a/test/srt/test_pp_single_node.py b/test/srt/test_pp_single_node.py
new file mode 100644
index 000000000..4d0b4adac
--- /dev/null
+++ b/test/srt/test_pp_single_node.py
@@ -0,0 +1,143 @@
+"""
+Usage:
+python3 -m unittest test_pp_single_node.TestPPAccuracy.test_gsm8k
+python3 -m unittest test_pp_single_node.TestFixedBugs.test_chunked_prefill_with_small_bs
+"""
+
+import os
+import time
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.bench_one_batch_server import BenchArgs as OneBatchBenchArgs
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval
+from sglang.test.runners import DEFAULT_PROMPTS
+from sglang.test.test_utils import (
+ DEFAULT_MODEL_NAME_FOR_TEST,
+ DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+ DEFAULT_URL_FOR_TEST,
+ popen_launch_server,
+ run_bench_one_batch_server,
+)
+
+
+class TestPPAccuracy(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ # This configuration helps surface memory leaks.
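+ # Setting SGLANG_IS_IN_CI turns on the CI-only checks, including the
+ # memory check that test_gsm8k waits for before shutdown.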
+ os.environ["SGLANG_IS_IN_CI"] = "1" + cls.base_url = "http://127.0.0.1:23333" + cls.process = popen_launch_server( + DEFAULT_MODEL_NAME_FOR_TEST, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--pp-size", + 4, + "--disable-overlap-schedule", + "--chunked-prefill-size", + 256, + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=5, + data_path=None, + num_questions=200, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval(args) + print(f"{metrics=}") + + self.assertGreater(metrics["accuracy"], 0.75) + # Wait a little bit so that the memory check happens. + time.sleep(5) + + +# class TestPPAccuracyFlashInfer(unittest.TestCase): +# @classmethod +# def setUpClass(cls): +# # These config helps find a leak. +# os.environ["SGLANG_IS_IN_CI"] = "1" +# cls.base_url = "http://127.0.0.1:23333" +# cls.process = popen_launch_server( +# DEFAULT_MODEL_NAME_FOR_TEST, +# cls.base_url, +# timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, +# other_args=[ +# "--pp-size", +# 4, +# "--disable-overlap-schedule", +# "--attention-backend", +# "flashinfer", +# "--chunked-prefill-size", +# 256, +# ], +# ) +# +# @classmethod +# def tearDownClass(cls): +# kill_process_tree(cls.process.pid) +# +# def test_gsm8k(self): +# args = SimpleNamespace( +# num_shots=5, +# data_path=None, +# num_questions=200, +# max_new_tokens=512, +# parallel=128, +# host="http://127.0.0.1", +# port=int(self.base_url.split(":")[-1]), +# ) +# metrics = run_eval(args) +# print(f"{metrics=}") +# +# self.assertGreater(metrics["accuracy"], 0.75) +# # Wait a little bit so that the memory check happens. +# time.sleep(5) + + +class TestFixedBugs(unittest.TestCase): + def test_chunked_prefill_with_small_bs(self): + model = DEFAULT_MODEL_NAME_FOR_TEST + server_args = ServerArgs(model_path=model) + bench_args = OneBatchBenchArgs( + batch_size=(1,), + input_len=(1,), + output_len=(1,), + base_url=DEFAULT_URL_FOR_TEST, + ) + other_server_args = [ + "--tp-size", + 2, + "--pp-size", + 2, + "--disable-overlap-schedule", + "--chunked-prefill", + 256, + "--max-running-requests", + 2, + ] + run_bench_one_batch_server( + model, + DEFAULT_URL_FOR_TEST, + server_args, + bench_args, + other_server_args, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vlm_accuracy.py b/test/srt/test_vlm_accuracy.py index b89379b34..a56bc85dc 100644 --- a/test/srt/test_vlm_accuracy.py +++ b/test/srt/test_vlm_accuracy.py @@ -147,6 +147,8 @@ class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): gpu_id=0, tp_rank=0, tp_size=1, + pp_rank=0, + pp_size=1, nccl_port=12435, server_args=ServerArgs( model_path=self.model_path,