Support multi-node DP attention (#2925)

Co-authored-by: dhou-xai <dhou@x.ai>
This commit was authored by Lianmin Zheng on 2025-01-16 11:15:00 -08:00 and committed by GitHub.
parent 58f3f2b840
commit 8b6ce52e92
16 changed files with 287 additions and 137 deletions

View File

@@ -33,6 +33,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
AbortReq,
@@ -135,7 +136,17 @@ class Scheduler:
# Init inter-process communication
context = zmq.Context(2)
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
self.dp_size = server_args.dp_size
self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
compute_dp_attention_world_info(
server_args.enable_dp_attention,
self.tp_rank,
self.tp_size,
self.dp_size,
)
)
if self.attn_tp_rank == 0:
self.recv_from_tokenizer = get_zmq_socket(
context, zmq.PULL, port_args.scheduler_input_ipc_name
)
@@ -244,6 +255,7 @@ class Scheduler:
_,
) = self.tp_worker.get_worker_info()
self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
global_server_args_dict.update(worker_global_server_args_dict)
set_random_seed(self.random_seed)
@@ -447,6 +459,10 @@ class Scheduler:
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
if self.server_args.enable_dp_attention: # TODO: simplify this
batch = self.prepare_dp_attn_batch(batch)
self.cur_batch = batch
if batch:
@@ -479,7 +495,7 @@ class Scheduler:
def recv_requests(self) -> List[Req]:
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
if self.tp_rank == 0 or self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
recv_reqs = []
while True:
@@ -491,7 +507,40 @@ class Scheduler:
else:
recv_reqs = None
if self.tp_size != 1 and not self.server_args.enable_dp_attention:
if self.server_args.enable_dp_attention:
if self.attn_tp_rank == 0:
work_reqs = [
req
for req in recv_reqs
if isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
)
]
control_reqs = [
req
for req in recv_reqs
if not isinstance(
req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
)
]
else:
work_reqs = None
control_reqs = None
if self.attn_tp_size != 1:
attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
work_reqs = broadcast_pyobj(
work_reqs,
self.attn_tp_rank,
self.attn_tp_cpu_group,
src=attn_tp_rank_0,
)
if self.tp_size != 1:
control_reqs = broadcast_pyobj(
control_reqs, self.tp_rank, self.tp_cpu_group
)
recv_reqs = work_reqs + control_reqs
elif self.tp_size != 1:
recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
return recv_reqs
@@ -887,7 +936,7 @@ class Scheduler:
self.being_chunked_req.is_being_chunked += 1
# Print stats
if self.tp_rank == 0:
if self.attn_tp_rank == 0:
self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
# Create a new batch
@@ -974,7 +1023,7 @@ class Scheduler:
self.forward_ct += 1
if self.is_generation:
if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
if self.spec_algorithm.is_none():
model_worker_batch = batch.get_model_worker_batch()
logits_output, next_token_ids = (
@@ -988,18 +1037,8 @@ class Scheduler:
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.num_generated_tokens += num_accepted_tokens
elif batch.forward_mode.is_idle():
model_worker_batch = batch.get_model_worker_batch()
self.tp_worker.forward_batch_idle(model_worker_batch)
return
else:
logits_output = None
if self.skip_tokenizer_init:
next_token_ids = torch.full(
(batch.batch_size(),), self.tokenizer.eos_token_id
)
else:
next_token_ids = torch.full((batch.batch_size(),), 0)
assert False, "batch.extend_num_tokens == 0, this is unexpected!"
batch.output_ids = next_token_ids
ret = logits_output, next_token_ids, model_worker_batch.bid
else: # embedding or reward model
@@ -1016,6 +1055,9 @@ class Scheduler:
self.running_batch = None
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result[-1])
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
@@ -1166,7 +1208,7 @@ class Scheduler:
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (
self.tp_rank == 0
self.attn_tp_rank == 0
and self.forward_ct_decode % self.server_args.decode_log_interval == 0
):
self.log_decode_stats()
@@ -1402,12 +1444,7 @@ class Scheduler:
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
(1 if local_batch.forward_mode.is_decode_or_idle() else 0),
dtype=torch.int32,
)
torch.distributed.all_reduce(