Support cuda graph for DP attention (#2061)
@@ -455,6 +455,7 @@ class ScheduleBatch:

    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False

    # For processing logprobs
    return_logprob: bool = False
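
These two fields are the only per-batch state the rest of the patch keys off: global_num_tokens carries one token count per TP rank, and can_run_dp_cuda_graph records whether every rank agreed that this step is graph-replayable. A minimal, self-contained sketch of that bookkeeping (the ToyBatch name is illustrative, not part of the patch):

from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ToyBatch:
    # One entry per TP rank: how many tokens that rank schedules this step.
    global_num_tokens: Optional[List[int]] = None
    # True only when every rank runs a decode or idle step, i.e. all ranks
    # can replay the same captured cuda graph together.
    can_run_dp_cuda_graph: bool = False

batch = ToyBatch(global_num_tokens=[8, 8, 7, 8], can_run_dp_cuda_graph=True)
print(max(batch.global_num_tokens))  # 8, the padding target for the shared graph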

@@ -891,6 +892,13 @@ class ScheduleBatch:

        self.seq_lens = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.out_cache_loc = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.req_pool_indices = torch.empty(0, dtype=torch.int32).to(
            self.device, non_blocking=True
        )
        self.seq_lens_sum = 0
        self.extend_num_tokens = 0

    def prepare_for_decode(self, enable_overlap: bool = False):
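
The zero-length tensors above give a rank with no work a well-formed idle batch, so it can still participate in the collective DP-attention forward. A small sketch of the same pattern, hedged to run on CPU when no GPU is present:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Zero-length placeholders: shapes and dtypes are valid, so downstream code
# that concatenates or gathers per-rank tensors needs no special case.
seq_lens = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)
out_cache_loc = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)
req_pool_indices = torch.empty(0, dtype=torch.int32).to(device, non_blocking=True)
seq_lens_sum = 0
extend_num_tokens = 0

print(seq_lens.numel(), out_cache_loc.numel(), req_pool_indices.numel())  # 0 0 0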

@@ -1032,6 +1040,7 @@ class ScheduleBatch:
            return_logprob=self.return_logprob,
            top_logprobs_nums=self.top_logprobs_nums,
            global_num_tokens=self.global_num_tokens,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
            extend_prefix_lens=extend_prefix_lens,

@@ -1093,6 +1102,7 @@ class ModelWorkerBatch:

    # For DP attention
    global_num_tokens: Optional[List[int]]
    can_run_dp_cuda_graph: bool

    # For extend
    extend_num_tokens: Optional[int]

@@ -337,7 +337,7 @@ class Scheduler:

            kill_parent_process()

    @torch.inference_mode()
    @torch.no_grad()
    def event_loop_normal(self):
        """A normal blocking scheduler loop."""
        self.last_batch = None

@@ -375,7 +375,7 @@ class Scheduler:

        self.last_batch = batch

    @torch.inference_mode()
    @torch.no_grad()
    def event_loop_overlap(self):
        """A scheduler loop that overlaps the CPU processing and GPU computation."""
        result_queue = deque()

@@ -411,16 +411,12 @@ class Scheduler:
        else:
            num_tokens = local_batch.extend_num_tokens

        local_num_tokens = torch.tensor(
            num_tokens, dtype=torch.int64, device=self.device
        )
        global_num_tokens = torch.empty(
            self.tp_size, dtype=torch.int64, device=self.device
        )
        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
        torch.distributed.all_gather_into_tensor(
            global_num_tokens,
            local_num_tokens,
            group=self.tp_worker.get_tp_device_group(),
            group=self.tp_cpu_group,
        )

        if local_batch is None and global_num_tokens.max().item() > 0:

@@ -429,6 +425,24 @@
        if local_batch is not None:
            local_batch.global_num_tokens = global_num_tokens.tolist()

            # Check forward mode for cuda graph
            if not self.server_args.disable_cuda_graph:
                forward_mode_state = torch.tensor(
                    (
                        1
                        if local_batch.forward_mode.is_decode()
                        or local_batch.forward_mode.is_idle()
                        else 0
                    ),
                    dtype=torch.int32,
                )
                torch.distributed.all_reduce(
                    forward_mode_state,
                    op=torch.distributed.ReduceOp.MIN,
                    group=self.tp_cpu_group,
                )
                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1

        return local_batch

    def get_idle_batch(self):
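
This is the core of the change: each TP rank publishes its local token count via an all-gather on the CPU (gloo) group, and the ranks then take a MIN over a 1/0 decode-or-idle indicator, so a captured graph is replayed only when every rank can replay it. Below is a hedged, single-process sketch of that agreement protocol (world size 1, gloo backend, illustrative port), not the scheduler code itself:

import os
import torch
import torch.distributed as dist

# Single-process "cluster" purely for illustration; real deployments have
# tp_size ranks, each running this same code.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29555")
dist.init_process_group(backend="gloo", rank=0, world_size=1)
tp_cpu_group = dist.group.WORLD
tp_size = dist.get_world_size()

# 1) Every rank publishes how many tokens it will run this step.
local_num_tokens = torch.tensor([3], dtype=torch.int64)
global_num_tokens = torch.empty(tp_size, dtype=torch.int64)
dist.all_gather_into_tensor(global_num_tokens, local_num_tokens, group=tp_cpu_group)

# 2) Ranks agree on graph eligibility: MIN over a decode/idle indicator, so a
#    single prefill rank disables graph replay for everyone this step.
is_decode_or_idle = True  # derived from forward_mode on a real rank
forward_mode_state = torch.tensor(1 if is_decode_or_idle else 0, dtype=torch.int32)
dist.all_reduce(forward_mode_state, op=dist.ReduceOp.MIN, group=tp_cpu_group)
can_run_dp_cuda_graph = forward_mode_state.item() == 1

print(global_num_tokens.tolist(), can_run_dp_cuda_graph)  # [3] True
dist.destroy_process_group()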

@@ -128,9 +128,6 @@ class TpModelWorker:
    def get_tp_cpu_group(self):
        return self.model_runner.tp_group.cpu_group

    def get_tp_device_group(self):
        return self.model_runner.tp_group.device_group

    def get_memory_pool(self):
        return (
            self.model_runner.req_to_token_pool,

@@ -83,9 +83,6 @@ class TpModelWorkerClient:
    def get_tp_cpu_group(self):
        return self.worker.get_tp_cpu_group()

    def get_tp_device_group(self):
        return self.worker.get_tp_device_group()

    def get_memory_pool(self):
        return (
            self.worker.model_runner.req_to_token_pool,

@@ -96,7 +93,7 @@ class TpModelWorkerClient:
        with torch.cuda.stream(self.forward_stream):
            self.forward_thread_func_()

    @torch.inference_mode()
    @torch.no_grad()
    def forward_thread_func_(self):
        while True:
            model_worker_batch, future_token_ids_ct = self.input_queue.get()

@@ -111,6 +111,8 @@ class CudaGraphRunner:
        self.use_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
        self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
        self.tp_size = self.model_runner.tp_size

        # Batch sizes to capture
        if model_runner.server_args.disable_cuda_graph_padding:

@@ -165,6 +167,16 @@ class CudaGraphRunner:
        else:
            self.encoder_lens = None

        if self.enable_dp_attention:
            self.global_num_tokens = [0] * self.tp_size
            self.gathered_buffer = torch.zeros(
                (
                    self.max_bs * self.tp_size,
                    self.model_runner.model_config.hidden_size,
                ),
                dtype=self.model_runner.dtype,
            )

        # Capture
        try:
            with self.model_capture_mode():
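
The gathered_buffer is sized for the worst case, max_bs tokens from each of tp_size ranks, so a graph captured once can be replayed against a per-step slice. A toy sketch of the sizing and slicing, with illustrative numbers:

import torch

# Illustrative numbers; in the patch these come from the model config,
# server args, and the largest captured batch size.
max_bs, tp_size, hidden_size = 160, 8, 4096
dtype = torch.float16

# One row per token that any rank may contribute at the largest captured
# size, so the buffer is allocated once and only sliced afterwards.
gathered_buffer = torch.zeros((max_bs * tp_size, hidden_size), dtype=dtype)

bs = 32                                       # padded batch size for this step
step_view = gathered_buffer[: bs * tp_size]   # slice used by this capture/replay
print(step_view.shape)                        # torch.Size([256, 4096])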

@@ -190,11 +202,21 @@ class CudaGraphRunner:
        self.model_runner.model.capture_mode = False

    def can_run(self, forward_batch: ForwardBatch):
        is_bs_supported = (
            forward_batch.batch_size in self.graphs
            if self.disable_padding
            else forward_batch.batch_size <= self.max_bs
        )
        if self.enable_dp_attention:
            min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max(
                forward_batch.global_num_tokens
            )
            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
                if self.disable_padding
                else max_num_tokens <= self.max_bs
            )
        else:
            is_bs_supported = (
                forward_batch.batch_size in self.graphs
                if self.disable_padding
                else forward_batch.batch_size <= self.max_bs
            )

        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
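
Restated as a standalone predicate (a sketch, not the runner's method), the DP-attention eligibility check above reads:

from typing import Dict, List

def dp_graph_supported(global_num_tokens: List[int], can_run_dp_cuda_graph: bool,
                       graphs: Dict[int, object], max_bs: int,
                       disable_padding: bool) -> bool:
    # All ranks must have agreed (decode or idle everywhere) before shapes matter.
    if not can_run_dp_cuda_graph:
        return False
    min_tokens, max_tokens = min(global_num_tokens), max(global_num_tokens)
    if disable_padding:
        # Without padding, every rank must hit exactly the same captured size.
        return min_tokens == max_tokens and max_tokens in graphs
    # With padding, the busiest rank only has to fit under the largest graph.
    return max_tokens <= max_bs

print(dp_graph_supported([4, 4, 4, 4], True, {1: None, 4: None}, 8, True))   # True
print(dp_graph_supported([4, 3, 4, 4], True, {1: None, 4: None}, 8, True))   # False
print(dp_graph_supported([4, 3, 4, 4], True, {1: None, 4: None}, 8, False))  # True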

@@ -239,6 +261,13 @@
        seq_lens_sum = seq_lens.sum().item()
        mrope_positions = self.mrope_positions[:, :bs]

        if self.enable_dp_attention:
            self.global_num_tokens[:] = [bs] * self.tp_size
            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
        else:
            self.global_num_tokens = None
            gathered_buffer = None

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,

@@ -265,6 +294,8 @@
            top_logprobs_nums=[0] * bs,
            positions=clamp_position(seq_lens),
            mrope_positions=mrope_positions,
            global_num_tokens=self.global_num_tokens,
            gathered_buffer=gathered_buffer,
        )
        logits_output = forward(input_ids, forward_batch.positions, forward_batch)
        return logits_output.next_token_logits
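
Capture fixes every shape, which is why global_num_tokens is pinned to [bs] * tp_size and the gathered buffer is sliced before recording. For reference, a generic PyTorch capture/replay sketch with static buffers (requires a CUDA device; this is the underlying pattern, not SGLang's capture code):

import torch

if torch.cuda.is_available():
    static_in = torch.zeros(8, 4096, device="cuda")
    static_out = torch.empty_like(static_in)

    # Warm up on a side stream before capture, as recommended by the PyTorch docs.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_out.copy_(static_in * 2)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_out.copy_(static_in * 2)

    # Replay: only the contents of static_in may change, never its shape,
    # which is why DP attention pads every rank to the same captured size.
    static_in.copy_(torch.randn(8, 4096, device="cuda"))
    g.replay()
    print(torch.allclose(static_out, static_in * 2))  # True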

@@ -295,7 +326,12 @@
        raw_bs = forward_batch.batch_size

        # Pad
        index = bisect.bisect_left(self.capture_bs, raw_bs)
        if self.enable_dp_attention:
            index = bisect.bisect_left(
                self.capture_bs, max(forward_batch.global_num_tokens)
            )
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.fill_(1)

@@ -310,6 +346,8 @@
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.enable_dp_attention:
            self.global_num_tokens[:] = [bs] * self.tp_size

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
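
During replay, padding is driven by the largest per-rank token count rather than the local batch size, so all ranks land on the same captured graph. A small sketch of that selection, with an illustrative capture_bs list:

import bisect

# Batch sizes for which graphs were captured, sorted ascending
# (illustrative values; the runner derives its own list from server args).
capture_bs = [1, 2, 4, 8, 16, 32]

def padded_bs(global_num_tokens, raw_bs, enable_dp_attention):
    # With DP attention every rank replays the same graph, so the padding
    # target is the largest per-rank token count, not the local batch size.
    target = max(global_num_tokens) if enable_dp_attention else raw_bs
    return capture_bs[bisect.bisect_left(capture_bs, target)]

print(padded_bs([3, 7, 6, 3], raw_bs=3, enable_dp_attention=True))   # 8
print(padded_bs([3, 7, 6, 3], raw_bs=3, enable_dp_attention=False))  # 4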

@@ -138,6 +138,7 @@ class ForwardBatch:
    # For DP attention
    global_num_tokens: Optional[List[int]] = None
    gathered_buffer: Optional[torch.Tensor] = None
    can_run_dp_cuda_graph: bool = False

    def compute_mrope_positions(
        self, model_runner: ModelRunner, batch: ModelWorkerBatch
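
The diff only adds the gathered_buffer field; how it is filled is not shown here, so the following is an assumption-level sketch of the gather it exists for: collecting each rank's hidden states into one contiguous buffer, simulated with a plain concatenation instead of a real collective:

import torch

tp_size, hidden_size = 4, 8
global_num_tokens = [3, 3, 3, 3]   # per-rank token counts for this step

# Preallocated once for the worst case in the runner; sliced per step.
gathered_buffer = torch.zeros((sum(global_num_tokens), hidden_size), dtype=torch.float16)

# Stand-in for a collective such as torch.distributed.all_gather_into_tensor:
per_rank_hidden = [
    torch.full((n, hidden_size), float(rank), dtype=torch.float16)
    for rank, n in enumerate(global_num_tokens)
]
gathered_buffer.copy_(torch.cat(per_rank_hidden, dim=0))

print(gathered_buffer.shape)  # torch.Size([12, 8])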

@@ -221,6 +222,7 @@ class ForwardBatch:
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
            global_num_tokens=batch.global_num_tokens,
            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
            lora_paths=batch.lora_paths,
            sampling_info=batch.sampling_info,
        )
|
||||
@@ -592,6 +592,9 @@ class ModelRunner:
|
||||
)
|
||||
|
||||
def forward_idle(self, forward_batch: ForwardBatch):
|
||||
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
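
forward_idle follows the same dispatch shape as the decode path: replay a captured graph when can_run says so, otherwise fall back to an eager forward. A toy sketch of that control flow with stand-in classes (names are illustrative only):

class FakeGraphRunner:
    def can_run(self, batch):
        # Stand-in for the eligibility check: all ranks decode/idle and size fits.
        return batch["can_run_dp_cuda_graph"]

    def replay(self, batch):
        return "replayed captured graph"

class FakeModel:
    def forward(self, batch):
        return "ran eager forward"

def forward_idle(cuda_graph_runner, model, batch):
    # Same control flow as the new ModelRunner.forward_idle: graph replay when
    # possible, eager execution otherwise.
    if cuda_graph_runner and cuda_graph_runner.can_run(batch):
        return cuda_graph_runner.replay(batch)
    return model.forward(batch)

print(forward_idle(FakeGraphRunner(), FakeModel(), {"can_run_dp_cuda_graph": True}))
print(forward_idle(FakeGraphRunner(), FakeModel(), {"can_run_dp_cuda_graph": False}))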

@@ -191,11 +191,12 @@ class ServerArgs:
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size = self.chunked_prefill_size // 2
            self.disable_cuda_graph = True
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            self.enable_overlap_schedule = False
            logger.warning(
                f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE workload issue. "
                "The CUDA graph is disabled. Data parallel size is adjust to be the same as tensor parallel size."
                f"The CUDA graph max batch size is adjusted to {self.cuda_graph_max_bs}. "
                "Data parallel size is adjusted to be the same as tensor parallel size."
            )

        if self.enable_overlap_schedule:
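
Compared to the removed lines, enabling DP attention no longer disables CUDA graphs outright; it caps the captured batch size at 96 and still halves the chunked prefill size. A standalone restatement of the adjustment with illustrative defaults (a sketch, not the real ServerArgs):

from dataclasses import dataclass

# Illustrative subset of the server arguments touched above; the real
# ServerArgs has many more fields and different defaults.
@dataclass
class ToyServerArgs:
    tp_size: int = 8
    dp_size: int = 1
    chunked_prefill_size: int = 8192
    cuda_graph_max_bs: int = 160
    enable_overlap_schedule: bool = True
    enable_dp_attention: bool = True

    def adjust_for_dp_attention(self):
        if self.enable_dp_attention:
            self.dp_size = self.tp_size
            self.chunked_prefill_size //= 2
            # CUDA graphs stay enabled; only the max captured batch size is capped.
            self.cuda_graph_max_bs = min(self.cuda_graph_max_bs, 96)
            self.enable_overlap_schedule = False

args = ToyServerArgs()
args.adjust_for_dp_attention()
print(args.dp_size, args.chunked_prefill_size, args.cuda_graph_max_bs)  # 8 4096 96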