Support multi-node DP attention (#2925)
Co-authored-by: dhou-xai <dhou@x.ai>
@@ -33,6 +33,7 @@ import zmq
 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
@@ -135,7 +136,17 @@ class Scheduler:
         # Init inter-process communication
         context = zmq.Context(2)
 
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
+        if self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name
             )
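
Note: compute_dp_attention_world_info is only imported and called here; its implementation is not part of this excerpt. A minimal sketch of the mapping it plausibly performs, assuming that enabling DP attention splits the tp_size ranks into dp_size attention-TP sub-groups of size tp_size // dp_size (the body below is an illustration under that assumption, not the actual sglang code):

def compute_dp_attention_world_info_sketch(enable_dp_attention, tp_rank, tp_size, dp_size):
    # Without DP attention, attention uses the full TP group and there is one DP group.
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    # With DP attention, each DP replica runs attention with TP degree tp_size // dp_size.
    attn_tp_size = tp_size // dp_size
    dp_rank = tp_rank // attn_tp_size      # which DP replica this rank belongs to
    attn_tp_rank = tp_rank % attn_tp_size  # rank within its attention-TP sub-group
    return attn_tp_rank, attn_tp_size, dp_rank
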
@@ -244,6 +255,7 @@ class Scheduler:
             _,
         ) = self.tp_worker.get_worker_info()
         self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
         self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         global_server_args_dict.update(worker_global_server_args_dict)
         set_random_seed(self.random_seed)
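
The worker now also exposes a CPU process group scoped to the attention-TP sub-group. How that group is constructed is not shown in this diff; a hedged sketch of one way to build such groups with torch.distributed.new_group (function name, backend choice, and loop structure are assumptions):

import torch.distributed as dist

def build_attn_tp_cpu_group_sketch(tp_size, attn_tp_size):
    # Every rank must call new_group() for every sub-group, even ones it does not join.
    my_rank = dist.get_rank()
    my_group = None
    for start in range(0, tp_size, attn_tp_size):
        ranks = list(range(start, start + attn_tp_size))
        group = dist.new_group(ranks, backend="gloo")  # CPU (gloo) group for pyobj broadcasts
        if my_rank in ranks:
            my_group = group
    return my_group
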
@@ -447,6 +459,10 @@ class Scheduler:
             self.process_input_requests(recv_reqs)
 
             batch = self.get_next_batch_to_run()
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
+                batch = self.prepare_dp_attn_batch(batch)
+
             self.cur_batch = batch
 
             if batch:
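
prepare_dp_attn_batch itself is outside this excerpt. The motivation is that with DP attention every DP rank must step the model in lockstep, so a rank with no runnable requests still joins the iteration with an empty IDLE batch. A rough sketch of that padding idea, with a hypothetical make_idle_batch() factory (illustrative only, not the method's actual body):

import torch
import torch.distributed as dist

def pad_with_idle_batch_sketch(local_batch, tp_cpu_group, make_idle_batch):
    # Does any DP rank have work this iteration?
    has_work = torch.tensor([0 if local_batch is None else 1], dtype=torch.int32)
    dist.all_reduce(has_work, op=dist.ReduceOp.MAX, group=tp_cpu_group)
    if local_batch is None and has_work.item() == 1:
        # Other ranks will run a forward pass; join them with an empty IDLE batch so
        # the collectives inside attention and logits processing stay in sync.
        local_batch = make_idle_batch()  # hypothetical factory for an idle ScheduleBatch
    return local_batch
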
@@ -479,7 +495,7 @@ class Scheduler:
 
     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
             recv_reqs = []
 
             while True:
@@ -491,7 +507,40 @@ class Scheduler:
         else:
             recv_reqs = None
 
-        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
             recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
         return recv_reqs
 
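
The hunk above splits incoming messages into work requests (TokenizedGenerateReqInput / TokenizedEmbeddingReqInput), which only concern this DP replica, and control requests (abort, flush, etc.), which every TP rank must see. A re-statement of that two-level fan-out using torch.distributed.broadcast_object_list in place of sglang's broadcast_pyobj (the function name, keyword arguments, and group handles below are illustrative assumptions):

import torch.distributed as dist
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
)

def fan_out_requests_sketch(recv_reqs, *, tp_rank, tp_group,
                            dp_rank, attn_tp_rank, attn_tp_size, attn_tp_group):
    if attn_tp_rank == 0:
        def is_work(r):
            return isinstance(r, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput))
        work = [r for r in recv_reqs if is_work(r)]
        ctrl = [r for r in recv_reqs if not is_work(r)]
    else:
        work, ctrl = None, None

    # Work requests stay inside this DP replica's attention-TP sub-group; the source
    # is that sub-group's first rank, i.e. global rank dp_rank * attn_tp_size.
    box = [work]
    dist.broadcast_object_list(box, src=dp_rank * attn_tp_size, group=attn_tp_group)
    work = box[0]

    # Control requests must reach every TP rank, so broadcast over the full TP group.
    box = [ctrl]
    dist.broadcast_object_list(box, src=0, group=tp_group)
    ctrl = box[0]
    return work + ctrl
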
@@ -887,7 +936,7 @@ class Scheduler:
             self.being_chunked_req.is_being_chunked += 1
 
         # Print stats
-        if self.tp_rank == 0:
+        if self.attn_tp_rank == 0:
             self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)
 
         # Create a new batch
@@ -974,7 +1023,7 @@ class Scheduler:
         self.forward_ct += 1
 
         if self.is_generation:
-            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                 if self.spec_algorithm.is_none():
                     model_worker_batch = batch.get_model_worker_batch()
                     logits_output, next_token_ids = (
@@ -988,18 +1037,8 @@ class Scheduler:
                         num_accepted_tokens,
                     ) = self.draft_worker.forward_batch_speculative_generation(batch)
                     self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
-                model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.forward_batch_idle(model_worker_batch)
-                return
             else:
-                logits_output = None
-                if self.skip_tokenizer_init:
-                    next_token_ids = torch.full(
-                        (batch.batch_size(),), self.tokenizer.eos_token_id
-                    )
-                else:
-                    next_token_ids = torch.full((batch.batch_size(),), 0)
+                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
             batch.output_ids = next_token_ids
             ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
@@ -1016,6 +1055,9 @@ class Scheduler:
                 self.running_batch = None
         elif batch.forward_mode.is_extend():
             self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result[-1])
         elif batch.forward_mode.is_dummy_first():
             batch.next_batch_sampling_info.update_regex_vocab_mask()
             self.current_stream.synchronize()
@@ -1166,7 +1208,7 @@ class Scheduler:
 
         self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
         if (
-            self.tp_rank == 0
+            self.attn_tp_rank == 0
             and self.forward_ct_decode % self.server_args.decode_log_interval == 0
         ):
             self.log_decode_stats()
@@ -1402,12 +1444,7 @@ class Scheduler:
         # Check forward mode for cuda graph
         if not self.server_args.disable_cuda_graph:
             forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                 dtype=torch.int32,
             )
             torch.distributed.all_reduce(
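
The all_reduce call is cut off at the end of this excerpt. Its purpose is to let DP ranks agree on whether every rank is running a CUDA-graph-compatible (decode or idle) step before replaying a graph. A hedged sketch of that gating pattern; the reduction op, group, and function name are assumptions, not taken from this diff:

import torch
import torch.distributed as dist

def can_use_cuda_graph_sketch(local_is_decode_or_idle: bool, cpu_group) -> bool:
    # Each rank contributes 1 if its local batch is decode/idle, else 0.
    state = torch.tensor(1 if local_is_decode_or_idle else 0, dtype=torch.int32)
    # Assumed MIN reduction: a single rank doing prefill disables graph replay for all.
    dist.all_reduce(state, op=dist.ReduceOp.MIN, group=cpu_group)
    return state.item() == 1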