Performance: Enable cuda graph for dp idle batch (#7269)
Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
```diff
@@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
             # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
             metadata_expand = self.target_verify_metadata_topk_expand[bs]
 
             # metadata_expand.max_seq_len_q = 1, already set in capture
             # metadata_expand.cu_seqlens_q already set in capture
 
             offsets = torch.arange(
                 self.speculative_num_draft_tokens, device=device
             ).unsqueeze(
                 0
             )  # shape: (1, self.speculative_num_draft_tokens)
 
             cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
             cum_len = torch.nn.functional.pad(
                 torch.cumsum(
```
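The `arange`/`unsqueeze`/`expand` sequence above builds, for each sequence, the page-table columns its draft tokens occupy right after the committed tokens. A minimal standalone sketch of the broadcasting (tensor values are illustrative, not from this PR):

```python
import torch

# Illustrative shapes only: 3 sequences, 4 draft tokens each.
speculative_num_draft_tokens = 4
seq_lens = torch.tensor([7, 2, 5])

offsets = torch.arange(speculative_num_draft_tokens).unsqueeze(0)  # (1, 4)
# Broadcast: row i holds seq_lens[i] + [0, 1, 2, 3], i.e. the slots that
# the draft tokens of sequence i occupy after its committed tokens.
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
print(cols)
# tensor([[ 7,  8,  9, 10],
#         [ 2,  3,  4,  5],
#         [ 5,  6,  7,  8]])
```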
```diff
@@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
             ).view(1, -1)
             # avoid extracting padded seq indices which will be out of boundary
             mask_extraction_indices[
-                :, spec_info.positions.numel() * self.speculative_num_draft_tokens :
+                :,
+                spec_info.positions.numel() * self.speculative_num_draft_tokens :,
             ].fill_(0)
 
             mask = spec_info.custom_mask[mask_extraction_indices].view(
                 -1, self.speculative_num_draft_tokens
             )  # (bsz * draft_num, draft_num)
 
             col_indices = offsets.expand(
                 mask.shape[0], self.speculative_num_draft_tokens
             )
             keys = torch.where(
-                mask, col_indices, col_indices + self.speculative_num_draft_tokens
+                mask,
+                col_indices,
+                col_indices + self.speculative_num_draft_tokens,
             )
             _, sort_order = torch.sort(keys, dim=1)
 
```
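The `torch.where`/`torch.sort` pair is a stable partition: masked-in columns keep keys `0..N-1` while masked-out columns are shifted to `N..2N-1`, so an ascending sort moves the kept columns to the front of each row with their relative order preserved. A self-contained sketch under illustrative inputs:

```python
import torch

N = 4  # stands in for speculative_num_draft_tokens
mask = torch.tensor([[True, False, True, False],
                     [False, False, True, True]])
col_indices = torch.arange(N).expand(mask.shape[0], N)

# Kept columns get small keys; rejected columns get keys shifted by N, so
# sorting each row partitions kept-first while preserving column order.
keys = torch.where(mask, col_indices, col_indices + N)
_, sort_order = torch.sort(keys, dim=1)
print(sort_order)
# tensor([[0, 2, 1, 3],
#         [2, 3, 0, 1]])
```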
```diff
@@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend):
                 .gather(1, cols)
                 .repeat_interleave(self.speculative_num_draft_tokens, dim=0)
             )  # (bsz, draft_num)
 
             metadata_expand.page_table.copy_(
                 non_masked_page_table.gather(1, sort_order)
             )
```
```diff
@@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend):
                     dtype=torch.int32,
                 )
             )
 
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
             metadata.cache_seqlens_int32.copy_(seq_lens)
```
```diff
@@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend):
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
 
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
```
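The new branch guards the idle-batch case: an idle data-parallel rank has accepted no draft tokens, so `accept_length_cpu` is an empty list and the old code would raise. A minimal reproduction of the failure mode (values are illustrative):

```python
accept_length_cpu = []  # an idle DP rank has accepted no draft tokens

# Old behavior: max([]) raises "ValueError: max() arg is an empty sequence".
# New behavior: fall back to a query length of 1 so the metadata stays valid.
max_seq_len_q = max(accept_length_cpu) + 1 if accept_length_cpu else 1
assert max_seq_len_q == 1
```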
```diff
@@ -1821,11 +1821,6 @@ class Scheduler(
         else:
             can_cuda_graph = 0
 
-        if not spec_algorithm.is_none():
-            # TODO(sang): Support cuda graph when idle batch is there.
-            if local_batch is None or local_batch.forward_mode.is_idle():
-                can_cuda_graph = 0
-
         is_extend_in_batch = (
             local_batch.forward_mode.is_extend() if local_batch else False
         )
```
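With idle batches now graphable, the per-rank opt-out above is removed. The reason the decision has to be group-wide at all: every DP rank must replay the same graph, or none can, since a rank that falls back to eager would desynchronize the collectives inside the graphed region. A schematic in plain Python (the real scheduler reaches agreement with a collective over the DP group, not a local list):

```python
def group_can_cuda_graph(per_rank_flags: list[int]) -> int:
    # Unanimity: one dissenting rank forces the whole group to eager mode.
    return min(per_rank_flags)

# Before this change, a rank holding an idle speculative batch reported 0
# and vetoed the group; now idle batches are graphable, so it reports 1.
assert group_can_cuda_graph([1, 1, 1, 1]) == 1
assert group_can_cuda_graph([1, 0, 1, 1]) == 0
```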
```diff
@@ -306,28 +306,30 @@ class CudaGraphRunner:
             self.encoder_lens = None
 
         if self.require_gathered_buffer:
+            self.gathered_buffer = torch.zeros(
+                (
+                    self.max_num_token,
+                    self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
             if self.require_mlp_tp_gather:
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.dp_size * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
                 self.global_num_tokens_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
             else:
                 assert self.require_attn_tp_gather
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_bs * self.num_tokens_per_bs,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
 
+        self.custom_mask = torch.ones(
+            (
+                (self.seq_lens.sum().item() + self.max_num_token)
+                * self.num_tokens_per_bs
+            ),
+            dtype=torch.bool,
+            device="cuda",
+        )
 
         # Capture
         try:
             with model_capture_mode():
```
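Hoisting `gathered_buffer` out of the branches and pre-allocating `custom_mask` follow the usual CUDA graph discipline: every tensor a captured graph touches must live at a fixed address, so buffers are allocated once, sized for the worst case, and refilled with `copy_` on replay. A generic sketch of that pattern, independent of the SGLang classes (the class and op below are invented for illustration):

```python
import torch

class GraphedStep:
    """Minimal sketch of the static-buffer pattern CUDA graphs require."""

    def __init__(self, max_tokens: int, hidden_size: int):
        # Allocated once, worst-case sized; the graph captures these addresses.
        self.inp = torch.zeros(max_tokens, hidden_size, device="cuda")
        self.out = torch.zeros(max_tokens, hidden_size, device="cuda")
        # Warm up once outside capture (allocator/kernel init happens here).
        self.out.copy_(self.inp * 2.0)
        torch.cuda.synchronize()
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.out.copy_(self.inp * 2.0)  # stand-in for the real model step

    def replay(self, x: torch.Tensor) -> torch.Tensor:
        # Replay never allocates: copy into the captured buffer, run, slice.
        self.inp[: x.shape[0]].copy_(x)
        self.graph.replay()
        return self.out[: x.shape[0]]
```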
```diff
@@ -674,11 +676,12 @@ class CudaGraphRunner:
         self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
-                forward_mode=forward_batch.forward_mode,
+                forward_mode=self.capture_forward_mode,
                 bs=bs,
                 num_token_non_padded=len(forward_batch.input_ids),
             )
+        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
+            forward_batch.spec_info.custom_mask = self.custom_mask
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
             bs,
```
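Both call sites now pass `self.capture_forward_mode` instead of the live batch's mode: a graph can only be replayed under the mode it was captured with, so an idle batch riding along on a decode-shaped graph must also present the captured mode downstream. A toy illustration of keying replays by the captured mode rather than the incoming one (all names here are invented for the sketch):

```python
# Toy sketch: graphs are stored under the mode they were captured with, so
# replay looks up (and reports downstream) the captured mode, not the
# forward mode of the batch being replayed.
CAPTURE_FORWARD_MODE = "decode"
graphs = {("decode", 8): "graph_decode_bs8"}  # (mode, batch size) -> graph

def pick_graph(batch_mode: str, bs: int) -> str:
    # An "idle" batch has no tokens of its own but still replays the
    # decode-shaped graph, padded, so collectives line up across DP ranks.
    assert batch_mode in ("decode", "idle")
    return graphs[(CAPTURE_FORWARD_MODE, bs)]

assert pick_graph("idle", 8) == "graph_decode_bs8"
```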
```diff
@@ -686,7 +689,7 @@ class CudaGraphRunner:
             self.seq_lens[:bs],
             forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
             self.encoder_lens[:bs] if self.is_encoder_decoder else None,
-            forward_batch.forward_mode,
+            self.capture_forward_mode,
             forward_batch.spec_info,
             seq_lens_cpu=self.seq_lens_cpu[:bs],
         )
```
```diff
@@ -736,11 +739,7 @@ class CudaGraphRunner:
         else:
             spec_info = EagleVerifyInput(
                 draft_token=None,
-                custom_mask=torch.ones(
-                    (num_tokens * self.model_runner.model_config.context_len),
-                    dtype=torch.bool,
-                    device="cuda",
-                ),
+                custom_mask=self.custom_mask,
                 positions=None,
                 retrive_index=None,
                 retrive_next_token=None,
```
```diff
@@ -99,6 +99,8 @@ class EagleDraftInput:
             topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
             topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
             capture_hidden_mode=capture_hidden_mode,
+            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
+            accept_length_cpu=[],
         )
 
     def prepare_extend_after_decode(
```
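`create_idle_input` now populates `accept_length`/`accept_length_cpu` with zero-length values instead of leaving them unset, so downstream metadata code runs unchanged on idle batches. Zero-size tensors flow cleanly through the cumsum-style constructions used by the attention backend:

```python
import torch

# A zero-length accept_length, as create_idle_input now provides.
accept_length = torch.empty((0,), dtype=torch.int32)

# A cu_seqlens-style construction degrades gracefully on the empty tensor:
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(accept_length, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_q)  # tensor([0], dtype=torch.int32): a valid, empty layout
```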
```diff
@@ -322,13 +322,11 @@ class EAGLEWorker(TpModelWorker):
         logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
             self.verify(batch, spec_info)
         )
-        need_forward, can_run_draft_extend_cuda_graph = (
-            self.check_forward_draft_extend_after_decode(batch)
-        )
-        if need_forward:
+        if self.check_forward_draft_extend_after_decode(batch):
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 self.forward_draft_extend_after_decode(
-                    batch, can_run_draft_extend_cuda_graph
+                    batch,
                 )
         return (
             logits_output,
```
```diff
@@ -344,7 +342,7 @@ class EAGLEWorker(TpModelWorker):
             and batch.spec_info.verified_id.shape[0] > 0
         )
         if not self.server_args.enable_dp_attention:
-            return local_need_forward, True
+            return local_need_forward
 
         global_need_forward = torch.tensor(
             [
```
```diff
@@ -357,10 +355,7 @@ class EAGLEWorker(TpModelWorker):
         )
         global_need_forward_cnt = global_need_forward[0].item()
         need_forward = global_need_forward_cnt > 0
-        can_run_draft_extend_cuda_graph = (
-            global_need_forward_cnt == get_tensor_model_parallel_world_size()
-        )
-        return need_forward, can_run_draft_extend_cuda_graph
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
```
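`check_forward_draft_extend_after_decode` now returns a single flag. A condensed sketch of the decision change (function names are invented; the collective over the DP group is replaced by a plain list):

```python
def decision_old(all_flags: list[bool], world_size: int) -> tuple[bool, bool]:
    cnt = sum(all_flags)
    # Old: draft-extend could use a cuda graph only when *every* rank had
    # real work, hence the separate cnt == world_size flag.
    return cnt > 0, cnt == world_size

def decision_new(all_flags: list[bool]) -> bool:
    # New: ranks with no work run an idle batch through the same graph, so
    # unanimity no longer gates cuda-graph use; only "does anyone need it".
    return sum(all_flags) > 0

assert decision_old([True, False], world_size=2) == (True, False)
assert decision_new([True, False]) is True
```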
```diff
@@ -816,15 +811,12 @@ class EAGLEWorker(TpModelWorker):
         assert forward_batch.spec_info is batch.spec_info
         self.capture_for_decode(logits_output, forward_batch.spec_info)
 
-    def forward_draft_extend_after_decode(
-        self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool
-    ):
+    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         # Backup fields that will be modified in-place
         seq_lens_backup = batch.seq_lens.clone()
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
 
         input_is_idle = batch.forward_mode.is_idle()
         if not input_is_idle:
             # Prepare metadata
```
```diff
@@ -836,14 +828,18 @@ class EAGLEWorker(TpModelWorker):
         else:
             batch = batch.copy()
             batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
             batch.spec_info = EagleDraftInput.create_idle_input(
                 device=self.device,
-                hidden_size=self.model_config.hidden_size,
+                hidden_size=hidden_size,
                 dtype=self.model_config.dtype,
                 topk=self.topk,
                 capture_hidden_mode=CaptureHiddenMode.LAST,
             )
 
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
```
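The tripled width reflects EAGLE-3's design: the draft head conditions on a concatenation of three target-model hidden states, so an idle placeholder built with the base `hidden_size` would not match the shapes that real batches produce. A shape-only illustration (the helper is invented for the sketch):

```python
import torch

def idle_hidden_width(hidden_size: int, is_eagle3: bool) -> int:
    # EAGLE-3 concatenates three target-model hidden states, so idle
    # placeholders must be three times as wide as the base hidden size.
    return hidden_size * 3 if is_eagle3 else hidden_size

hidden_states = torch.empty((0, idle_hidden_width(4096, True)))
assert hidden_states.shape == (0, 12288)
```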
```diff
@@ -858,8 +854,7 @@ class EAGLEWorker(TpModelWorker):
 
         # Run
         can_cuda_graph = (
-            can_run_draft_extend_cuda_graph
-            and self.cuda_graph_runner_for_draft_extend
+            self.cuda_graph_runner_for_draft_extend
             and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
         )
         if can_cuda_graph:
```