Support cuda graph for DP attention (#2061)

This commit is contained in:
Ke Bao
2024-11-18 08:29:20 +08:00
committed by GitHub
parent 11f881d173
commit 62832bb272
9 changed files with 88 additions and 26 deletions

View File

@@ -337,7 +337,7 @@ class Scheduler:
kill_parent_process()
@torch.inference_mode()
@torch.no_grad()
def event_loop_normal(self):
"""A normal blocking scheduler loop."""
self.last_batch = None
@@ -375,7 +375,7 @@ class Scheduler:
self.last_batch = batch
@torch.inference_mode()
@torch.no_grad()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
@@ -411,16 +411,12 @@ class Scheduler:
else:
num_tokens = local_batch.extend_num_tokens
local_num_tokens = torch.tensor(
num_tokens, dtype=torch.int64, device=self.device
)
global_num_tokens = torch.empty(
self.tp_size, dtype=torch.int64, device=self.device
)
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
torch.distributed.all_gather_into_tensor(
global_num_tokens,
local_num_tokens,
group=self.tp_worker.get_tp_device_group(),
group=self.tp_cpu_group,
)
if local_batch is None and global_num_tokens.max().item() > 0:
@@ -429,6 +425,24 @@ class Scheduler:
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens.tolist()
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
forward_mode_state = torch.tensor(
(
1
if local_batch.forward_mode.is_decode()
or local_batch.forward_mode.is_idle()
else 0
),
dtype=torch.int32,
)
torch.distributed.all_reduce(
forward_mode_state,
op=torch.distributed.ReduceOp.MIN,
group=self.tp_cpu_group,
)
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
return local_batch
def get_idle_batch(self):