Support cuda graph for DP attention (#2061)
This commit is contained in:
@@ -337,7 +337,7 @@ class Scheduler:
|
||||
|
||||
kill_parent_process()
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def event_loop_normal(self):
|
||||
"""A normal blocking scheduler loop."""
|
||||
self.last_batch = None
|
||||
@@ -375,7 +375,7 @@ class Scheduler:
|
||||
|
||||
self.last_batch = batch
|
||||
|
||||
@torch.inference_mode()
|
||||
@torch.no_grad()
|
||||
def event_loop_overlap(self):
|
||||
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
|
||||
result_queue = deque()
|
||||
@@ -411,16 +411,12 @@ class Scheduler:
|
||||
else:
|
||||
num_tokens = local_batch.extend_num_tokens
|
||||
|
||||
local_num_tokens = torch.tensor(
|
||||
num_tokens, dtype=torch.int64, device=self.device
|
||||
)
|
||||
global_num_tokens = torch.empty(
|
||||
self.tp_size, dtype=torch.int64, device=self.device
|
||||
)
|
||||
local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
|
||||
global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
|
||||
torch.distributed.all_gather_into_tensor(
|
||||
global_num_tokens,
|
||||
local_num_tokens,
|
||||
group=self.tp_worker.get_tp_device_group(),
|
||||
group=self.tp_cpu_group,
|
||||
)
|
||||
|
||||
if local_batch is None and global_num_tokens.max().item() > 0:
|
||||
@@ -429,6 +425,24 @@ class Scheduler:
|
||||
if local_batch is not None:
|
||||
local_batch.global_num_tokens = global_num_tokens.tolist()
|
||||
|
||||
# The DP cuda graph is usable only if every rank in tp_cpu_group is in decode or idle mode (MIN all-reduce below)
|
||||
if not self.server_args.disable_cuda_graph:
|
||||
forward_mode_state = torch.tensor(
|
||||
(
|
||||
1
|
||||
if local_batch.forward_mode.is_decode()
|
||||
or local_batch.forward_mode.is_idle()
|
||||
else 0
|
||||
),
|
||||
dtype=torch.int32,
|
||||
)
|
||||
torch.distributed.all_reduce(
|
||||
forward_mode_state,
|
||||
op=torch.distributed.ReduceOp.MIN,
|
||||
group=self.tp_cpu_group,
|
||||
)
|
||||
local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
|
||||
|
||||
return local_batch
|
||||
|
||||
def get_idle_batch(self):
|
||||
|
||||
Reference in New Issue
Block a user