Support multi-node DP attention (#2925)
Co-authored-by: dhou-xai <dhou@x.ai>
@@ -26,8 +26,8 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port, you can use the following commands. If you meet deadlock, please try to add `--disable-cuda-graph`
```
# Node 0
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 0
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 0

# Node 1
-python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --nccl-init sgl-dev-0:50000 --nnodes 2 --node-rank 1
+python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct --tp 4 --dist-init-addr sgl-dev-0:50000 --nnodes 2 --node-rank 1
```
@@ -11,9 +11,9 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instr
```bash
# on the first node, replace 172.16.4.52:20000 with your own node ip address and port

-python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0
+python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.16.4.52:20000 --nnodes 2 --node-rank 0

# on the second node, replace 172.18.45.52:20000 with your own node ip address and port

-python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --nccl-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
+python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-405B-Instruct --tp 16 --dist-init-addr 172.18.45.52:20000 --nnodes 2 --node-rank 1
```
@@ -18,6 +18,7 @@ import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import is_flashinfer_available
@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype=model_runner.kv_cache_dtype,
            num_attention_heads=model_runner.model_config.num_attention_heads
-            // model_runner.tp_size,
+            // get_attention_tp_size(),
            num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                model_runner.tp_size
+                get_attention_tp_size()
            ),
        )
        self.max_context_len = model_runner.model_config.context_len
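The change above divides the head counts by the attention-TP group size instead of the full TP size. A minimal sketch of the arithmetic, assuming illustrative sizes (32 query heads, 8 KV heads, tp_size=8, dp_size=2) and assuming `get_num_kv_heads` floors at one head per rank; none of these values come from the commit:

```python
# Illustrative sizes only; not taken from the diff.
num_attention_heads, num_kv_heads = 32, 8
tp_size, dp_size = 8, 2

attn_tp_size = tp_size // dp_size                          # 4 ranks per attention shard
qo_heads_per_rank = num_attention_heads // attn_tp_size    # 8 (was 32 // 8 = 4 with full TP)
kv_heads_per_rank = max(1, num_kv_heads // attn_tp_size)   # 2 (was 1 with full TP)
print(attn_tp_size, qo_heads_per_rank, kv_heads_per_rank)
```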
@@ -147,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
            self.prefill_cuda_graph_metadata = {}

    def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
@@ -238,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
            decode_wrappers = []
            for i in range(self.num_wrappers):
                decode_wrappers.append(
@@ -307,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
        forward_mode: ForwardMode,
        spec_info: Optional[SpecInfo],
    ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                req_pool_indices[:bs],
                seq_lens[:bs],
@@ -453,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
@@ -625,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
        # Parse Constants
        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
        )
        self.head_dim = model_runner.model_config.head_dim
        self.data_type = model_runner.kv_cache_dtype
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
import torch

from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
        self.decode_attention_fwd = decode_attention_fwd
        self.extend_attention_fwd = extend_attention_fwd

-        if model_runner.server_args.enable_dp_attention:
-            self.num_head = model_runner.model_config.num_attention_heads
-        else:
-            self.num_head = (
-                model_runner.model_config.num_attention_heads // model_runner.tp_size
-            )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )

        self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
        self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
python/sglang/srt/layers/dp_attention.py (new file, 68 lines)
@@ -0,0 +1,68 @@
+import torch
+from vllm.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        False,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
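A worked example of the rank bookkeeping introduced above. This is a minimal sketch assuming tp_size=4 and dp_size=2 (values chosen only for illustration); the helper is copied verbatim from the new file:

```python
def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    # Copied from dp_attention.py above.
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, dp_rank

# Assumed sizes: tp_size=4, dp_size=2 -> attn_tp_size=2.
for tp_rank in range(4):
    print(tp_rank, compute_dp_attention_world_info(True, tp_rank, 4, 2))
# tp_rank 0 -> (0, 2, 0)   tp_rank 1 -> (1, 2, 0)   # DP rank 0 owns TP ranks [0, 1]
# tp_rank 2 -> (0, 2, 1)   tp_rank 3 -> (1, 2, 1)   # DP rank 1 owns TP ranks [2, 3]

# The GroupCoordinator above is built from the same layout:
attn_tp_size = 2
groups = [list(range(head, head + attn_tp_size)) for head in range(0, 4, attn_tp_size)]
print(groups)  # [[0, 1], [2, 3]]
```

Each contiguous block of `attn_tp_size` TP ranks forms one attention-TP sub-group, and the block index is that group's DP rank.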
@@ -133,7 +133,7 @@ class LogitsProcessor(nn.Module):

        # Get the last hidden states and last logits for the next token prediction
        if (
-            logits_metadata.forward_mode.is_decode()
+            logits_metadata.forward_mode.is_decode_or_idle()
            or logits_metadata.forward_mode.is_target_verify()
        ):
            last_index = None
@@ -23,6 +23,7 @@ import psutil
import setproctitle
import zmq

+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.managers.io_struct import (
    TokenizedEmbeddingReqInput,
    TokenizedGenerateReqInput,
@@ -63,9 +64,10 @@ class DataParallelController:

        # Init inter-process communication
        self.context = zmq.Context(1 + server_args.dp_size)
-        self.recv_from_tokenizer = get_zmq_socket(
-            self.context, zmq.PULL, port_args.scheduler_input_ipc_name
-        )
+        if server_args.node_rank == 0:
+            self.recv_from_tokenizer = get_zmq_socket(
+                self.context, zmq.PULL, port_args.scheduler_input_ipc_name
+            )

        # Dispatch method
        self.round_robin_counter = 0
@@ -75,33 +77,47 @@ class DataParallelController:
        }
        self.dispatching = dispatch_lookup[self.load_balance_method]

-        # Start data parallel workers
-        base_gpu_id = 0
+        # Launch data parallel workers
+        self.scheduler_procs = []
        self.workers = [None] * server_args.dp_size

+        if not server_args.enable_dp_attention:
+            dp_port_args = self.launch_dp_schedulers(server_args, port_args)
+        else:
+            dp_port_args = self.launch_dp_attention_schedulers(server_args, port_args)
+
+        # Only node rank 0 runs the real data parallel controller that dispatches the requests.
+        if server_args.node_rank == 0:
+            for dp_rank in range(server_args.dp_size):
+                self.workers[dp_rank] = get_zmq_socket(
+                    self.context,
+                    zmq.PUSH,
+                    dp_port_args[dp_rank].scheduler_input_ipc_name,
+                )
+
+    def launch_dp_schedulers(self, server_args, port_args):
+        base_gpu_id = 0
+
        threads = []
        sockets = []
+        dp_port_args = []
        for dp_rank in range(server_args.dp_size):
            tmp_port_args = PortArgs.init_new(server_args)
            tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
            tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
+            dp_port_args.append(tmp_port_args)

-            if server_args.enable_dp_attention:
-                # Data parallelism reuses the tensor parallelism group,
-                # so all dp ranks should use the same nccl port.
-                tmp_port_args.nccl_port = port_args.nccl_port
-            else:
-                # This port is checked free in PortArgs.init_new.
-                # We hold it first so that the next dp worker gets a different port
-                sockets.append(bind_port(tmp_port_args.nccl_port))
+            # This port is checked free in PortArgs.init_new.
+            # We hold it first so that the next dp worker gets a different port
+            sockets.append(bind_port(tmp_port_args.nccl_port))

            # Create a thread for each worker
            thread = threading.Thread(
-                target=self.launch_worker_func,
+                target=self.launch_tensor_parallel_group,
                args=(server_args, tmp_port_args, base_gpu_id, dp_rank),
            )
            threads.append(thread)
-            base_gpu_id += 1 if server_args.enable_dp_attention else server_args.tp_size
+            base_gpu_id += server_args.tp_size

        # Free all sockets before starting the threads to launch TP workers
        for sock in sockets:
@@ -113,26 +129,14 @@ class DataParallelController:
        for thread in threads:
            thread.join()

-    def launch_worker_func(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+        return dp_port_args

-        launch_func_ = (
-            self.launch_tensor_parallel_process
-            if server_args.enable_dp_attention
-            else self.launch_tensor_parallel_group
-        )
-        self.workers[dp_rank] = launch_func_(
-            server_args,
-            port_args,
-            base_gpu_id,
-            dp_rank,
-        )
+    def launch_dp_attention_schedulers(self, server_args, port_args):
+        self.launch_tensor_parallel_group(server_args, port_args, 0, None)
+        dp_port_args = []
+        for dp_rank in range(server_args.dp_size):
+            dp_port_args.append(PortArgs.init_new(server_args, dp_rank))
+        return dp_port_args

    def launch_tensor_parallel_group(
        self,
@@ -141,8 +145,10 @@ class DataParallelController:
        base_gpu_id: int,
        dp_rank: int,
    ):
+        if not server_args.enable_dp_attention:
+            logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
+
        # Launch tensor parallel scheduler processes
-        scheduler_procs = []
        scheduler_pipe_readers = []
        tp_size_per_node = server_args.tp_size // server_args.nnodes
        tp_rank_range = range(
@@ -150,53 +156,39 @@ class DataParallelController:
            tp_size_per_node * (server_args.node_rank + 1),
        )
        for tp_rank in tp_rank_range:
+            rank_port_args = port_args
+
+            if server_args.enable_dp_attention:
+                # dp attention has different sharding logic
+                _, _, dp_rank = compute_dp_attention_world_info(
+                    server_args.enable_dp_attention,
+                    tp_rank,
+                    server_args.tp_size,
+                    server_args.dp_size,
+                )
+                # compute zmq ports for this dp rank
+                rank_port_args = PortArgs.init_new(server_args, dp_rank)
+                # Data parallelism reuses the tensor parallelism group,
+                # so all dp ranks should use the same nccl port.
+                rank_port_args.nccl_port = port_args.nccl_port
+
            reader, writer = mp.Pipe(duplex=False)
            gpu_id = server_args.base_gpu_id + base_gpu_id + tp_rank % tp_size_per_node
            proc = mp.Process(
                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
+                args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
            )
            proc.start()
-            scheduler_procs.append(proc)
+            self.scheduler_procs.append(proc)
            scheduler_pipe_readers.append(reader)

-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        # Wait for model to finish loading and get max token nums
+        # Wait for model to finish loading
        scheduler_info = []
        for i in range(len(scheduler_pipe_readers)):
            scheduler_info.append(scheduler_pipe_readers[i].recv())

        self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]

-        return send_to
-
-    def launch_tensor_parallel_process(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        base_gpu_id: int,
-        dp_rank: int,
-    ):
-        reader, writer = mp.Pipe(duplex=False)
-        gpu_id = base_gpu_id
-        tp_rank = dp_rank
-        proc = mp.Process(
-            target=run_scheduler_process,
-            args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
-        )
-        proc.start()
-        send_to = get_zmq_socket(
-            self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
-        )
-
-        scheduler_info = reader.recv()
-        self.max_total_num_tokens = scheduler_info["max_total_num_tokens"]
-
-        return send_to

    def round_robin_scheduler(self, req):
        self.workers[self.round_robin_counter].send_pyobj(req)
        self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
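A minimal sketch of the scheduler placement math in `launch_tensor_parallel_group` above, assuming tp_size=8, nnodes=2, dp_size=4, `server_args.base_gpu_id=0`, and `base_gpu_id=0` (the value `launch_dp_attention_schedulers` passes); all sizes are illustrative, not taken from the commit:

```python
# Assumed sizes for illustration only.
tp_size, nnodes, dp_size, base_gpu_id = 8, 2, 4, 0
attn_tp_size = tp_size // dp_size  # 2

for node_rank in range(nnodes):
    tp_size_per_node = tp_size // nnodes                       # 4 TP ranks per node
    tp_rank_range = range(tp_size_per_node * node_rank,
                          tp_size_per_node * (node_rank + 1))
    for tp_rank in tp_rank_range:
        gpu_id = base_gpu_id + tp_rank % tp_size_per_node      # local GPU index on this node
        dp_rank = tp_rank // attn_tp_size                      # which DP replica this rank serves
        print(f"node {node_rank}: tp_rank {tp_rank} -> GPU {gpu_id}, DP rank {dp_rank}")
# node 0: tp_rank 0..3 -> GPUs 0..3 (DP ranks 0, 0, 1, 1)
# node 1: tp_rank 4..7 -> GPUs 0..3 (DP ranks 2, 2, 3, 3)
```

Every node launches only its own slice of TP ranks, which is what makes the DP-attention scheduler layout work across nodes.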
@@ -221,8 +213,8 @@ class DataParallelController:
        ):
            self.dispatching(recv_req)
        else:
-            # Send other control messages to all workers
-            for worker in self.workers:
+            # Send other control messages to first worker of tp group
+            for worker in self.workers[:: self.server_args.tp_size]:
                worker.send_pyobj(recv_req)

@@ -240,7 +232,13 @@ def run_data_parallel_controller_process(
        pipe_writer.send(
            {"status": "ready", "max_total_num_tokens": controller.max_total_num_tokens}
        )
-        controller.event_loop()
+        if server_args.node_rank == 0:
+            controller.event_loop()
+        for proc in controller.scheduler_procs:
+            proc.join()
+            logger.error(
+                f"Scheduler or DataParallelController {proc.pid} terminated with {proc.exitcode}"
+            )
    except Exception:
        traceback = get_exception_traceback()
        logger.error(f"DataParallelController hit an exception: {traceback}")
@@ -1003,6 +1003,11 @@ class ScheduleBatch:
        self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+            enable_overlap_schedule=self.enable_overlap,
+        )

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
@@ -1117,7 +1122,7 @@ class ScheduleBatch:
            self.spec_info.merge_batch(other.spec_info)

    def get_model_worker_batch(self):
-        if self.forward_mode.is_decode() or self.forward_mode.is_idle():
+        if self.forward_mode.is_decode_or_idle():
            extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
        else:
            extend_seq_lens = self.extend_lens
@@ -33,6 +33,7 @@ import zmq
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.io_struct import (
    AbortReq,
@@ -135,7 +136,17 @@ class Scheduler:
        # Init inter-process communication
        context = zmq.Context(2)

-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+            compute_dp_attention_world_info(
+                server_args.enable_dp_attention,
+                self.tp_rank,
+                self.tp_size,
+                self.dp_size,
+            )
+        )
+
+        if self.attn_tp_rank == 0:
            self.recv_from_tokenizer = get_zmq_socket(
                context, zmq.PULL, port_args.scheduler_input_ipc_name
            )
@@ -244,6 +255,7 @@ class Scheduler:
            _,
        ) = self.tp_worker.get_worker_info()
        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.attn_tp_cpu_group = self.tp_worker.get_attention_tp_cpu_group()
        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
        global_server_args_dict.update(worker_global_server_args_dict)
        set_random_seed(self.random_seed)
@@ -447,6 +459,10 @@ class Scheduler:
            self.process_input_requests(recv_reqs)

            batch = self.get_next_batch_to_run()
+
+            if self.server_args.enable_dp_attention:  # TODO: simplify this
+                batch = self.prepare_dp_attn_batch(batch)
+
            self.cur_batch = batch

            if batch:
@@ -479,7 +495,7 @@ class Scheduler:

    def recv_requests(self) -> List[Req]:
        """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
-        if self.tp_rank == 0 or self.server_args.enable_dp_attention:
+        if self.attn_tp_rank == 0:
            recv_reqs = []

            while True:
@@ -491,7 +507,40 @@ class Scheduler:
        else:
            recv_reqs = None

-        if self.tp_size != 1 and not self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention:
+            if self.attn_tp_rank == 0:
+                work_reqs = [
+                    req
+                    for req in recv_reqs
+                    if isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+                control_reqs = [
+                    req
+                    for req in recv_reqs
+                    if not isinstance(
+                        req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                    )
+                ]
+            else:
+                work_reqs = None
+                control_reqs = None
+
+            if self.attn_tp_size != 1:
+                attn_tp_rank_0 = self.dp_rank * self.attn_tp_size
+                work_reqs = broadcast_pyobj(
+                    work_reqs,
+                    self.attn_tp_rank,
+                    self.attn_tp_cpu_group,
+                    src=attn_tp_rank_0,
+                )
+            if self.tp_size != 1:
+                control_reqs = broadcast_pyobj(
+                    control_reqs, self.tp_rank, self.tp_cpu_group
+                )
+            recv_reqs = work_reqs + control_reqs
+        elif self.tp_size != 1:
            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
        return recv_reqs
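A minimal sketch of who sources each broadcast in the new `recv_requests`, assuming tp_size=4 and dp_size=2 (so attn_tp_size=2), and assuming `broadcast_pyobj` defaults to rank 0 of the given group when no `src` is passed; the sizes are illustrative only:

```python
# Assumed sizes for illustration only.
tp_size, dp_size = 4, 2
attn_tp_size = tp_size // dp_size

for tp_rank in range(tp_size):
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    work_src = dp_rank * attn_tp_size   # attn_tp_rank_0 in the code above
    control_src = 0                     # full-TP broadcast, sourced by the group's rank 0
    print(tp_rank, dp_rank, attn_tp_rank, work_src, control_src)
# Work requests fan out within each attention-TP group (rank 0 -> 1, rank 2 -> 3);
# control messages are broadcast once over the whole TP group.
```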
@@ -887,7 +936,7 @@ class Scheduler:
                self.being_chunked_req.is_being_chunked += 1

        # Print stats
-        if self.tp_rank == 0:
+        if self.attn_tp_rank == 0:
            self.log_prefill_stats(adder, can_run_list, running_bs, has_being_chunked)

        # Create a new batch
@@ -974,7 +1023,7 @@ class Scheduler:
        self.forward_ct += 1

        if self.is_generation:
-            if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
+            if batch.forward_mode.is_decode_or_idle() or batch.extend_num_tokens != 0:
                if self.spec_algorithm.is_none():
                    model_worker_batch = batch.get_model_worker_batch()
                    logits_output, next_token_ids = (
@@ -988,18 +1037,8 @@ class Scheduler:
                        num_accepted_tokens,
                    ) = self.draft_worker.forward_batch_speculative_generation(batch)
                    self.num_generated_tokens += num_accepted_tokens
-            elif batch.forward_mode.is_idle():
-                model_worker_batch = batch.get_model_worker_batch()
-                self.tp_worker.forward_batch_idle(model_worker_batch)
-                return
            else:
-                logits_output = None
-                if self.skip_tokenizer_init:
-                    next_token_ids = torch.full(
-                        (batch.batch_size(),), self.tokenizer.eos_token_id
-                    )
-                else:
-                    next_token_ids = torch.full((batch.batch_size(),), 0)
+                assert False, "batch.extend_num_tokens == 0, this is unexpected!"
            batch.output_ids = next_token_ids
            ret = logits_output, next_token_ids, model_worker_batch.bid
        else:  # embedding or reward model
@@ -1016,6 +1055,9 @@ class Scheduler:
                self.running_batch = None
        elif batch.forward_mode.is_extend():
            self.process_batch_result_prefill(batch, result)
+        elif batch.forward_mode.is_idle():
+            if self.enable_overlap:
+                self.tp_worker.resolve_batch_result(result[-1])
        elif batch.forward_mode.is_dummy_first():
            batch.next_batch_sampling_info.update_regex_vocab_mask()
            self.current_stream.synchronize()
@@ -1166,7 +1208,7 @@ class Scheduler:

        self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
        if (
-            self.tp_rank == 0
+            self.attn_tp_rank == 0
            and self.forward_ct_decode % self.server_args.decode_log_interval == 0
        ):
            self.log_decode_stats()
@@ -1402,12 +1444,7 @@ class Scheduler:
        # Check forward mode for cuda graph
        if not self.server_args.disable_cuda_graph:
            forward_mode_state = torch.tensor(
-                (
-                    1
-                    if local_batch.forward_mode.is_decode()
-                    or local_batch.forward_mode.is_idle()
-                    else 0
-                ),
+                (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
                dtype=torch.int32,
            )
            torch.distributed.all_reduce(
@@ -101,6 +101,7 @@ class TpModelWorker:
                self.max_total_num_tokens // 2
                if server_args.max_running_requests is None
                else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
            ),
            self.model_runner.req_to_token_pool.size,
        )
@@ -142,16 +143,15 @@ class TpModelWorker:
    def get_tp_cpu_group(self):
        return self.model_runner.tp_group.cpu_group

+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )

-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
@@ -92,6 +92,9 @@ class TpModelWorkerClient:
    def get_tp_cpu_group(self):
        return self.worker.get_tp_cpu_group()

+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,
@@ -122,6 +122,7 @@ class CudaGraphRunner:
        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
        self.tp_size = self.model_runner.tp_size
+        self.dp_size = self.model_runner.server_args.dp_size

        # Batch sizes to capture
        self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +219,7 @@ class CudaGraphRunner:
        if self.enable_dp_attention:
            self.gathered_buffer = torch.zeros(
                (
-                    self.max_bs * self.tp_size,
+                    self.max_bs * self.dp_size,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,
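The gathered buffer is now sized by the number of DP replicas rather than by the TP world size. A quick sketch of the difference, with assumed values (max_bs=160, tp_size=8, dp_size=2, hidden_size=5120; none are taken from the diff):

```python
# Assumed values for illustration only.
max_bs, hidden_size, tp_size, dp_size = 160, 5120, 8, 2

old_rows = max_bs * tp_size  # 1280 rows: one slot per TP rank, most of them unused
new_rows = max_bs * dp_size  # 320 rows: each DP rank contributes at most max_bs tokens
print(old_rows, new_rows, (old_rows - new_rows) * hidden_size)  # rows saved * hidden dim
```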
@@ -35,6 +35,10 @@ from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttn
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    initialize_dp_attention,
+)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -235,11 +239,18 @@ class ModelRunner:
            distributed_init_method=dist_init_method,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        initialize_dp_attention(
+            enable_dp_attention=self.server_args.enable_dp_attention,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
+            dp_size=self.server_args.dp_size,
+        )

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
        self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
@@ -855,10 +855,9 @@ class DeepseekV2ForCausalLM(nn.Module):
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
-        if not forward_batch.forward_mode.is_idle():
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
@@ -239,15 +239,14 @@ class ServerArgs:

        # Others
        if self.enable_dp_attention:
+            assert self.tp_size % self.dp_size == 0
-            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.schedule_conservativeness = self.schedule_conservativeness * 0.3
            self.disable_overlap_schedule = True
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
                "Overlap scheduler is disabled."
            )

        # Speculative Decoding
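A small sketch of the adjustments applied when DP attention is enabled, using assumed starting values (chunked_prefill_size=8192, schedule_conservativeness=1.0, tp_size=8, dp_size=2); these starting values are illustrative, not guaranteed defaults of the library:

```python
# Assumed starting values, for illustration only.
chunked_prefill_size = 8192
schedule_conservativeness = 1.0
tp_size, dp_size = 8, 2

assert tp_size % dp_size == 0            # dp_size must now divide tp_size
chunked_prefill_size //= 2               # 4096, to avoid MoE kernel issues
schedule_conservativeness *= 0.3         # 0.3
disable_overlap_schedule = True
print(chunked_prefill_size, schedule_conservativeness, disable_overlap_schedule)
```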
@@ -880,8 +879,8 @@ class ServerArgs:
            self.tp_size % self.nnodes == 0
        ), "tp_size must be divisible by number of nodes"
        assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention!"
        assert (
            self.max_loras_per_batch > 0
            # FIXME
@@ -919,6 +918,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
    return server_args


+ZMQ_TCP_PORT_DELTA = 233
+
+
@dataclasses.dataclass
class PortArgs:
    # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -932,7 +934,7 @@ class PortArgs:
    nccl_port: int

    @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
        port = server_args.port + random.randint(100, 1000)
        while True:
            if is_port_available(port):
@@ -942,12 +944,39 @@ class PortArgs:
            else:
                port -= 43

-        return PortArgs(
-            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+            assert (
+                len(dist_init_addr) == 2
+            ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )


class LoRAPathAction(argparse.Action):
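A minimal sketch of the TCP port layout that `PortArgs.init_new` derives in the DP-attention branch above, assuming `--dist-init-addr sgl-dev-0:50000` and two DP ranks (the host and port are placeholders):

```python
# Assumed head-node address; mirrors the derivation in PortArgs.init_new above.
dist_init_host, dist_init_port = "sgl-dev-0", 50000
port_base = dist_init_port + 1                                    # 50001

tokenizer_ipc_name = f"tcp://{dist_init_host}:{port_base}"        # 50001
detokenizer_ipc_name = f"tcp://{dist_init_host}:{port_base + 1}"  # 50002

def scheduler_input_port(dp_rank=None):
    if dp_rank is None:                    # TokenizerManager -> DataParallelController
        return port_base + 2               # 50003
    return port_base + 2 + 1 + dp_rank     # 50004 + dp_rank, one per scheduler

print(tokenizer_ipc_name, detokenizer_ipc_name,
      scheduler_input_port(), scheduler_input_port(0), scheduler_input_port(1))
```

On a single node without `--dist-init-addr`, the same layout is anchored at `127.0.0.1` and `server_args.port + ZMQ_TCP_PORT_DELTA` (233), per the branch above.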
@@ -802,11 +802,11 @@ def get_zmq_socket(context: zmq.Context, socket_type: zmq.SocketType, endpoint:
    if socket_type == zmq.PUSH:
        socket.setsockopt(zmq.SNDHWM, 0)
        socket.setsockopt(zmq.SNDBUF, buf_size)
-        socket.connect(f"ipc://{endpoint}")
+        socket.connect(endpoint)
    elif socket_type == zmq.PULL:
        socket.setsockopt(zmq.RCVHWM, 0)
        socket.setsockopt(zmq.RCVBUF, buf_size)
-        socket.bind(f"ipc://{endpoint}")
+        socket.bind(endpoint)
    else:
        raise ValueError(f"Unsupported socket type: {socket_type}")
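With this change callers pass a full ZMQ endpoint (`ipc://...` or `tcp://host:port`) instead of a bare filename, which is what lets the DP-attention path span nodes. A self-contained pyzmq sketch of the same bind/connect pattern; the port number and the use of the loopback address are assumptions for the sake of a runnable example, not taken from the diff:

```python
import zmq

ctx = zmq.Context(2)

# Receiver side (e.g. a scheduler on the head node) binds the full TCP endpoint.
pull = ctx.socket(zmq.PULL)
pull.setsockopt(zmq.RCVHWM, 0)
pull.bind("tcp://127.0.0.1:50003")

# Sender side (e.g. the controller, possibly on another node) connects by host:port.
push = ctx.socket(zmq.PUSH)
push.setsockopt(zmq.SNDHWM, 0)
push.connect("tcp://127.0.0.1:50003")

push.send_pyobj({"hello": "scheduler"})
print(pull.recv_pyobj())
```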