From 8b6ce52e92ab390952e75e2fc68c90d4e3f7928c Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Thu, 16 Jan 2025 11:15:00 -0800
Subject: [PATCH] Support multi-node DP attention (#2925)

Co-authored-by: dhou-xai
---
 docs/backend/server_arguments.md              |   4 +-
 docs/references/llama_405B.md                 |   4 +-
 .../layers/attention/flashinfer_backend.py    |  19 +--
 .../srt/layers/attention/triton_backend.py    |  10 +-
 python/sglang/srt/layers/dp_attention.py      |  68 +++++++++
 python/sglang/srt/layers/logits_processor.py  |   2 +-
 .../srt/managers/data_parallel_controller.py  | 140 +++++++++---------
 python/sglang/srt/managers/schedule_batch.py  |   7 +-
 python/sglang/srt/managers/scheduler.py       |  83 ++++++++---
 python/sglang/srt/managers/tp_worker.py       |   8 +-
 .../srt/managers/tp_worker_overlap_thread.py  |   3 +
 .../srt/model_executor/cuda_graph_runner.py   |   3 +-
 .../sglang/srt/model_executor/model_runner.py |  11 ++
 python/sglang/srt/models/deepseek_v2.py       |   7 +-
 python/sglang/srt/server_args.py              |  51 +++++--
 python/sglang/srt/utils.py                    |   4 +-
 16 files changed, 287 insertions(+), 137 deletions(-)
 create mode 100644 python/sglang/srt/layers/dp_attention.py

diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md
index 90b36a0bd..6d72aa55a 100644
--- a/docs/backend/server_arguments.md
+++ b/docs/backend/server_arguments.md
@@ -26,8 +26,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 - To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, you can use the following commands. If you meet deadlock, please try to add `--disable-cuda-graph`
 ```
 # Node 0
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0

 # Node 1
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
 ```
diff --git a/docs/references/llama_405B.md b/docs/references/llama_405B.md
index 4f70e89f6..075aac030 100644
--- a/docs/references/llama_405B.md
+++ b/docs/references/llama_405B.md
@@ -11,9 +11,9 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr

 ```bash
 # on the first node, replace 172.16.4.52:20000 with your own node ip address and port
-python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
+python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0

 # on the second node, replace 172.18.45.52:20000 with your own node ip address and port
-python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
+python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
 ```
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 6a4636128..7540515c5 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,6 +18,7 @@
 import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
-            // model_runner.tp_size,
+            // get_attention_tp_size(),
             num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                model_runner.tp_size
+                get_attention_tp_size()
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
             self.prefill_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
                 decode_wrappers.append(
@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 04327b162..fade8ed29 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import torch

 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd

-        if model_runner.server_args.enable_dp_attention:
-            self.num_head = model_runner.model_config.num_attention_heads
-        else:
-            self.num_head = (
-                model_runner.model_config.num_attention_heads // model_runner.tp_size
-            )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )

         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
diff --git a/python/sglang/srt/layers/dp_attention.py b/python/sglang/srt/layers/dp_attention.py
new file mode 100644
index 000000000..41bcb2181
--- /dev/null
+++ b/python/sglang/srt/layers/dp_attention.py
@@ -0,0 +1,68 @@
+import torch
+from vllm.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        False,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index f5b12b48a..e1dc94548 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -133,7 +133,7 @@ class LogitsProcessor(nn.Module):

         # Get the last hidden states and last logits for the next token prediction
         if (
-            logits_metadata.forward_mode.is_decode()
+            logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
             last_index = None
diff --git a/python/sglang/srt/managers/data_parallel_controller.py b/python/sglang/srt/managers/data_parallel_controller.py
index 7ae6689ee..c4ebbb3cf 100644
--- a/python/sglang/srt/managers/data_parallel_controller.py
+++ b/python/sglang/srt/managers/data_parallel_controller.py
@@ -23,6 +23,7 @@ import psutil
 import setproctitle
 import zmq

+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -63,9 +64,10 @@ class DataParallelController:

         # Init inter-process communication
         self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = get_zmq_socket(
-            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
-        )
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )

         # Dispatch method
         self.round_robin_counter = 0
@@ -75,33 +77,47 @@ class DataParallelController:
         }
         self.dispatching = dispatch_lookup[self.load_balance_method]

-        # Start data parallel workers
-        base_gpu_id = 0
+        # Launch data parallel workers
+        self.scheduler_procs = []
         self.workers = [None] * server_args.dp_size

+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                )
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
         threads = []
         sockets = []
+        dp_port_args = []
         for dp_rank in range(server_args.dp_size):
             tmp_port_args = PortArgs.init_new(server_args)
             tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
             tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)

-            if server_args.enable_dp_attention:
-                # Data parallelism resues the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))

             # Create a thread for each worker
             thread = threading.Thread(
-                target=self.launch_worker_func,
+                target=self.launch_tensor_parallel_group,
                 args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
             )
             threads.append(thread)
-            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+            base_gpu_id += server_args.tp_size

         # Free all sockets before starting the threads to launch TP workers
         for sock in sockets:
@@ -113,26 +129,14 @@ class DataParallelController:
         for thread in threads:
             thread.join()

-    def launch_worker_func(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args

-        launch_func_ = (
-            self.launch_tensor_parallel_process
-            if server_args.enable_dp_attention
-            else self.launch_tensor_parallel_group
-        )
-        self.workers[dp_rank] = launch_func_(
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args

     def launch_tensor_parallel_group(
         self,
@@ -141,8 +145,10 @@ class DataParallelController:
         server_args: ServerArgs,
         port_args: PortArgs,
         base_gpu_id: int,
         dp_rank: int,
     ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
         # Launch tensor parallel scheduler processes
-        scheduler_procs = []
         scheduler_pipe_readers = []
@@ -150,53 +156,39 @@ class DataParallelController:
         tp_size_per_node = server_args.tp_size // server_args.nnodes
         tp_rank_range = range(
             tp_size_per_node * server_args.node_rank,
             tp_size_per_node * (server_args.node_rank + 1),
         )
         for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
             reader, writer = mp.Pipe(duplex=False)
             gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
             proc = mp.Process(
                 target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
             proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)

-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
         scheduler_info = []
         for i in range(len(scheduler_pipe_readers)):
            scheduler_info.append(scheduler_pipe_readers[i].recv())
         self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to

     def round_robin_scheduler(self, req):
         self.workers[self.round_robin_counter].send_pyobj(req)
         self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
@@ -221,8 +213,8 @@ class DataParallelController:
             ):
                 self.dispatching(recv_req)
             else:
-                # Send other control messages to all workers
-                for worker in self.workers:
+                # Send other control messages to first worker of tp group
+                for worker in self.workers[:: self.server_args.tp_size]:
                     worker.send_pyobj(recv_req)
@@ -240,7 +232,13 @@ def run_data_parallel_controller_process(
         pipe_writer.send(
             {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
         )
-        controller.event_loop()
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
     except Exception:
         traceback = get_exception_traceback()
         logger.error(f"DataParallelController hit an exception: {traceback}")
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index c375df234..654c944ca 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1003,6 +1003,11 @@ class ScheduleBatch:
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
         self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )

     def prepare_for_decode(self):
         self.forward_mode = ForwardMode.DECODE
@@ -1117,7 +1122,7 @@ class ScheduleBatch:
             self.spec_info.merge_batch(other.spec_info)

     def get_model_worker_batch(self):
-        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
+        if self.forward_mode.is_decode_or_idle():
             extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
         else:
             extend_seq_lens = self.extend_lens
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 6ee93b3cd..62dc22ef2 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -33,6 +33,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -135,7 +136,17 @@ class Scheduler:

         # Init inter-process communication
         context = zmq.Context(2)
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
+        if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
@@ -244,6 +255,7 @@ class Scheduler:
             _,
         ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
@@ -447,6 +459,10 @@ class Scheduler:
             self.process_input_requests(recv_reqs)

             batch = self.get_next_batch_to_run()
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch

             if batch:
@@ -479,7 +495,7 @@ class Scheduler:

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             recv_reqs = []

             while True:
@@ -491,7 +507,40 @@ class Scheduler:
         else:
             recv_reqs = None

-        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
@@ -887,7 +936,7 @@ class Scheduler:
                 self.being_chunked_req.is_being_chunked += 1

         # Print stats
-        if self.tp_rank == 0:
+        if self.attn_tp_rank == 0:
             self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

         # Create a new batch
@@ -974,7 +1023,7 @@ class Scheduler:
         self.forward_ct += 1

         if self.is_generation:
-            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                 if self.spec_algorithm.is_none():
                     model_worker_batch = batch.get_model_worker_batch()
                     logits_output, next_token_ids = (
@@ -988,18 +1037,8 @@ class Scheduler:
                        num_accepted_tokens,
                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
                    self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
-                model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.forward_batch_idle(model_worker_batch)
-                return
             else:
-                logits_output = None
-                if self.skip_tokenizer_init:
-                    next_token_ids = torch.full(
-                        (batch.batch_size(),), self.tokenizer.eos_token_id
-                    )
-                else:
-                    next_token_ids = torch.full((batch.batch_size(),), 0)
+                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
             ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
@@ -1016,6 +1055,9 @@ class Scheduler:
                 self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result[-1])
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
@@ -1166,7 +1208,7 @@ class Scheduler:
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)

         if (
-            self.tp_rank == 0
+            self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats()
@@ -1402,12 +1444,7 @@ class Scheduler:
         # Check forward mode for cuda graph
         if not self.server_args.disable_cuda_graph:
             forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                 dtype=torch.int32,
             )
             torch.distributed.all_reduce(
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 25a1c85f2..47e3eea40 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -101,6 +101,7 @@ class TpModelWorker:
                 self.max_total_num_tokens // 2
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
             ),
             self.model_runner.req_to_token_pool.size,
         )
@@ -142,16 +143,15 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group

+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )

-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
diff --git a/python/sglang/srt/managers/tp_worker_overlap_thread.py b/python/sglang/srt/managers/tp_worker_overlap_thread.py
index 2aa9c8269..64c34a851 100644
--- a/python/sglang/srt/managers/tp_worker_overlap_thread.py
+++ b/python/sglang/srt/managers/tp_worker_overlap_thread.py
@@ -92,6 +92,9 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()

+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index e4580b5e2..e167ff16a 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -122,6 +122,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
         self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size

         # Batch sizes to capture
         self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +219,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.tp_size,
+                    self.max_bs * self.dp_size,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 238f8603a..d238c9195 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -35,6 +35,10 @@ from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttn
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    initialize_dp_attention,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -235,11 +239,18 @@ class ModelRunner:
             distributed_init_method=dist_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        initialize_dp_attention(
+            enable_dp_attention=self.server_args.enable_dp_attention,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            dp_size=self.server_args.dp_size,
+        )
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index a9c0b59ce..19a73a86e 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        if not forward_batch.forward_mode.is_idle():
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index e445217b6..df98bdeb3 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -239,15 +239,14 @@ class ServerArgs:

         # Others
         if self.enable_dp_attention:
+            assert self.tp_size % self.dp_size == 0
             self.dp_size = self.tp_size
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap scheduler is disabled."
             )

         # Speculative Decoding
@@ -880,8 +879,8 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention is enabled!"
         assert (
             self.max_loras_per_batch > 0
             # FIXME
@@ -919,6 +918,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     return server_args


+ZMQ_TCP_PORT_DELTA = 233
+
+
 @dataclasses.dataclass
 class PortArgs:
     # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -932,7 +934,7 @@ class PortArgs:
     nccl_port: int

     @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
@@ -942,12 +944,39 @@ class PortArgs:
             else:
                 port -= 43

-        return PortArgs(
-            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+            assert (
+                len(dist_init_addr) == 2
+            ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )


 class LoRAPathAction(argparse.Action):
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index 583dd92e1..f1603ec0e 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -802,11 +802,11 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
     if socket_type == zmq.PUSH:
         socket.setsockopt(zmq.SNDHWM, 0)
         socket.setsockopt(zmq.SNDBUF, buf_size)
-        socket.connect(f"ipc://{endpoint}")
+        socket.connect(endpoint)
     elif socket_type == zmq.PULL:
         socket.setsockopt(zmq.RCVHWM, 0)
         socket.setsockopt(zmq.RCVBUF, buf_size)
-        socket.bind(f"ipc://{endpoint}")
+        socket.bind(endpoint)
     else:
         raise ValueError(f"Unsupported socket type: {socket_type}")
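
Note (for reviewers): the core of this patch is the rank bookkeeping in the new `compute_dp_attention_world_info` helper. The sketch below re-implements that function standalone (copied from the diff above, so it runs without sglang installed) and prints the resulting layout for a hypothetical `--tp 8 --dp-size 2 --enable-dp-attention` launch; the numbers are illustrative, not taken from a real run.

```python
# Standalone copy of compute_dp_attention_world_info from this patch, for illustration.
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0

    attn_tp_size = tp_size // dp_size      # GPUs sharding attention within one DP group
    dp_rank = tp_rank // attn_tp_size      # which DP group this global TP rank falls into
    attn_tp_rank = tp_rank % attn_tp_size  # position inside that DP group
    return attn_tp_rank, attn_tp_size, dp_rank


if __name__ == "__main__":
    # Hypothetical launch: tp_size=8, dp_size=2 -> two DP groups of 4 GPUs each,
    # matching the group_ranks list built in initialize_dp_attention.
    for tp_rank in range(8):
        attn_tp_rank, attn_tp_size, dp_rank = compute_dp_attention_world_info(
            True, tp_rank, 8, 2
        )
        print(f"tp_rank={tp_rank}: dp_rank={dp_rank}, attn_tp_rank={attn_tp_rank}/{attn_tp_size}")
    # tp_rank 0-3 -> dp_rank 0; tp_rank 4-7 -> dp_rank 1; attn_tp_rank cycles 0-3.
    # Ranks with attn_tp_rank == 0 (tp_rank 0 and 4) are the ones the Scheduler lets
    # pull requests from ZMQ and broadcast to the rest of their attention-TP group.
```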
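Note: under `--enable-dp-attention` the patch also switches `PortArgs.init_new` from per-process IPC files to a deterministic TCP port layout derived from `--dist-init-addr`, so schedulers on every node can compute the same endpoints independently. A minimal sketch of that port arithmetic, assuming a hypothetical head node `sgl-dev-0:50000` and `dp_size=2` (values are made up for the example):

```python
# Mirrors the endpoint scheme in PortArgs.init_new for the DP-attention case.
dist_init_host, dist_init_port = "sgl-dev-0:50000".split(":")
port_base = int(dist_init_port) + 1

endpoints = {
    "tokenizer": f"tcp://{dist_init_host}:{port_base}",             # 50001
    "detokenizer": f"tcp://{dist_init_host}:{port_base + 1}",       # 50002
    # dp_rank is None -> TokenizerManager feeding the DataParallelController
    "controller_input": f"tcp://{dist_init_host}:{port_base + 2}",  # 50003
}
for dp_rank in range(2):  # one scheduler input endpoint per DP rank
    endpoints[f"scheduler_input_dp{dp_rank}"] = (
        f"tcp://{dist_init_host}:{port_base + 2 + 1 + dp_rank}"     # 50004, 50005
    )

print(endpoints)
# Single-node runs without --dist-init-addr fall back to 127.0.0.1 with
# server port + ZMQ_TCP_PORT_DELTA (233 in this patch) as the virtual head address.
```

This is why `get_zmq_socket` in utils.py now takes a full endpoint string instead of prepending `ipc://` itself: the same code path must accept both `ipc://` and `tcp://` endpoints.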