[Fix] Fix flashinfer cpu <-> gpu synchronization (#8340)
This commit is contained in:
@@ -66,6 +66,10 @@ class PrefillMetadata:
|
|||||||
# Reuse this workspace buffer across all flashinfer wrappers
|
# Reuse this workspace buffer across all flashinfer wrappers
|
||||||
global_workspace_buffer = None
|
global_workspace_buffer = None
|
||||||
|
|
||||||
|
# Use as a fast path to override the indptr in flashinfer's plan function
|
||||||
|
# This is used to remove some host-to-device copy overhead.
|
||||||
|
global_override_indptr_cpu = None
|
||||||
|
|
||||||
|
|
||||||
class FlashInferAttnBackend(AttentionBackend):
|
class FlashInferAttnBackend(AttentionBackend):
|
||||||
"""Flashinfer attention kernels."""
|
"""Flashinfer attention kernels."""
|
||||||
@@ -205,6 +209,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_cpu,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
decode_wrappers=self.decode_wrappers,
|
decode_wrappers=self.decode_wrappers,
|
||||||
encoder_lens=forward_batch.encoder_lens,
|
encoder_lens=forward_batch.encoder_lens,
|
||||||
@@ -215,6 +220,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_cpu,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=self.prefill_wrappers_paged,
|
prefill_wrappers=self.prefill_wrappers_paged,
|
||||||
@@ -229,6 +235,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_cpu,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=self.prefill_wrappers_verify,
|
prefill_wrappers=self.prefill_wrappers_verify,
|
||||||
@@ -252,6 +259,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.seq_lens_cpu,
|
||||||
forward_batch.seq_lens_sum,
|
forward_batch.seq_lens_sum,
|
||||||
prefix_lens,
|
prefix_lens,
|
||||||
prefill_wrappers=self.prefill_wrappers_paged,
|
prefill_wrappers=self.prefill_wrappers_paged,
|
||||||
@@ -327,6 +335,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
|
seq_lens.cpu(), # may add a little overhead in capture stage
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
decode_wrappers=decode_wrappers,
|
decode_wrappers=decode_wrappers,
|
||||||
encoder_lens=encoder_lens,
|
encoder_lens=encoder_lens,
|
||||||
@@ -358,6 +367,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
|
seq_lens.cpu(), # may add a little overhead in capture stage
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=prefill_wrappers,
|
prefill_wrappers=prefill_wrappers,
|
||||||
@@ -387,6 +397,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
seq_lens,
|
seq_lens,
|
||||||
|
seq_lens.cpu(), # may add a little overhead in capture stage
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=prefill_wrappers,
|
prefill_wrappers=prefill_wrappers,
|
||||||
@@ -414,6 +425,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
|
seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
decode_wrappers=self.decode_cuda_graph_metadata[bs],
|
||||||
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
|
||||||
@@ -423,6 +435,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
|
seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
|
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
|
||||||
@@ -434,6 +447,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
self.indices_updater_prefill.update(
|
self.indices_updater_prefill.update(
|
||||||
req_pool_indices[:bs],
|
req_pool_indices[:bs],
|
||||||
seq_lens[:bs],
|
seq_lens[:bs],
|
||||||
|
seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
|
||||||
seq_lens_sum,
|
seq_lens_sum,
|
||||||
prefix_lens=None,
|
prefix_lens=None,
|
||||||
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
|
prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
|
||||||
@@ -581,7 +595,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
|
|
||||||
class FlashInferIndicesUpdaterDecode:
|
class FlashInferIndicesUpdaterDecode:
|
||||||
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
|
||||||
# Parse Constants
|
# Parse Constants
|
||||||
self.num_qo_heads = (
|
self.num_qo_heads = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
@@ -614,6 +628,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
@@ -626,6 +641,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
@@ -640,30 +656,39 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.kv_indptr[0],
|
self.kv_indptr[0],
|
||||||
None,
|
None,
|
||||||
spec_info,
|
spec_info,
|
||||||
|
seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
def update_sliding_window(
|
def update_sliding_window(
|
||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
|
assert self.sliding_window_size is not None
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
# Sliding window attention
|
# Sliding window attention
|
||||||
paged_kernel_lens_tmp = torch.minimum( # TODO: replace this with clamp
|
paged_kernel_lens_tmp = torch.clamp(
|
||||||
seq_lens,
|
seq_lens, max=self.sliding_window_size + 1
|
||||||
torch.tensor(self.sliding_window_size + 1),
|
|
||||||
)
|
)
|
||||||
paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
|
if seq_lens_cpu is not None:
|
||||||
|
seq_lens_cpu_tmp = torch.clamp(
|
||||||
|
seq_lens_cpu, max=self.sliding_window_size + 1
|
||||||
|
)
|
||||||
|
paged_kernel_lens_sum_tmp = seq_lens_cpu_tmp.sum().item()
|
||||||
|
else:
|
||||||
|
paged_kernel_lens_sum_tmp = paged_kernel_lens_tmp.sum().item()
|
||||||
kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
|
kv_start_idx_tmp = seq_lens - paged_kernel_lens_tmp
|
||||||
else:
|
else:
|
||||||
# Full attention
|
# Full attention
|
||||||
paged_kernel_lens_tmp = seq_lens
|
paged_kernel_lens_tmp = seq_lens
|
||||||
paged_kernel_lens_sum_tmp = seq_lens_sum
|
paged_kernel_lens_sum_tmp = seq_lens_sum
|
||||||
|
seq_lens_cpu_tmp = seq_lens_cpu
|
||||||
kv_start_idx_tmp = None
|
kv_start_idx_tmp = None
|
||||||
|
|
||||||
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
use_sliding_window_kv_pool = wrapper_id == 0 and isinstance(
|
||||||
@@ -678,6 +703,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
kv_start_idx_tmp,
|
kv_start_idx_tmp,
|
||||||
spec_info,
|
spec_info,
|
||||||
|
seq_lens_cpu=seq_lens_cpu_tmp,
|
||||||
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
use_sliding_window_kv_pool=use_sliding_window_kv_pool,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -685,6 +711,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
@@ -709,6 +736,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.kv_indptr[wrapper_id],
|
self.kv_indptr[wrapper_id],
|
||||||
kv_start_idx,
|
kv_start_idx,
|
||||||
spec_info,
|
spec_info,
|
||||||
|
seq_lens_cpu=seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
def call_begin_forward(
|
def call_begin_forward(
|
||||||
@@ -720,6 +748,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
kv_start_idx: torch.Tensor,
|
kv_start_idx: torch.Tensor,
|
||||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
):
|
):
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
@@ -756,6 +785,14 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
global global_override_indptr_cpu
|
||||||
|
locally_override = False
|
||||||
|
if seq_lens_cpu is not None and global_override_indptr_cpu is None:
|
||||||
|
locally_override = True
|
||||||
|
global_override_indptr_cpu = torch.empty_like(kv_indptr, device="cpu")
|
||||||
|
global_override_indptr_cpu[0] = 0
|
||||||
|
global_override_indptr_cpu[1 : bs + 1] = torch.cumsum(seq_lens_cpu, dim=0)
|
||||||
|
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -769,9 +806,12 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
non_blocking=True,
|
non_blocking=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if locally_override:
|
||||||
|
global_override_indptr_cpu = None
|
||||||
|
|
||||||
|
|
||||||
class FlashInferIndicesUpdaterPrefill:
|
class FlashInferIndicesUpdaterPrefill:
|
||||||
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
|
def __init__(self, model_runner: ModelRunner, attn_backend: FlashInferAttnBackend):
|
||||||
# Parse Constants
|
# Parse Constants
|
||||||
self.num_qo_heads = (
|
self.num_qo_heads = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
@@ -806,6 +846,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
@@ -820,6 +861,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
@@ -853,6 +895,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
@@ -898,6 +941,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
prefix_lens: torch.Tensor,
|
prefix_lens: torch.Tensor,
|
||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
@@ -1020,11 +1064,6 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# Use as a fast path to override the indptr in flashinfer's plan function
|
|
||||||
# This is used to remove some host-to-device copy overhead.
|
|
||||||
global global_override_indptr_cpu
|
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMultiStepDraftBackend:
|
class FlashInferMultiStepDraftBackend:
|
||||||
"""
|
"""
|
||||||
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
||||||
@@ -1056,7 +1095,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.kv_last_page_len = torch.ones(
|
self.kv_last_page_len = torch.ones(
|
||||||
(max_bs,), dtype=torch.int32, device=model_runner.device
|
(max_bs,), dtype=torch.int32, device=model_runner.device
|
||||||
)
|
)
|
||||||
self.attn_backends = []
|
self.attn_backends: List[FlashInferAttnBackend] = []
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends.append(
|
self.attn_backends.append(
|
||||||
FlashInferAttnBackend(
|
FlashInferAttnBackend(
|
||||||
@@ -1176,7 +1215,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
encoder_lens=None,
|
encoder_lens=None,
|
||||||
forward_mode=ForwardMode.DECODE,
|
forward_mode=ForwardMode.DECODE,
|
||||||
spec_info=forward_batch.spec_info,
|
spec_info=forward_batch.spec_info,
|
||||||
seq_lens_cpu=None,
|
seq_lens_cpu=forward_batch.seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|||||||
@@ -1714,16 +1714,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
|
attention_backend_str = global_server_args_dict["prefill_attention_backend"]
|
||||||
# Create seq_lens_cpu when needed
|
# Create seq_lens_cpu when needed
|
||||||
if (
|
if (
|
||||||
attention_backend_str == "fa3"
|
attention_backend_str
|
||||||
or (
|
in [
|
||||||
global_server_args_dict["use_mla_backend"]
|
"fa3",
|
||||||
and attention_backend_str == "flashinfer"
|
"flashinfer",
|
||||||
)
|
"flashmla",
|
||||||
or attention_backend_str == "flashmla"
|
"cutlass_mla",
|
||||||
or attention_backend_str == "cutlass_mla"
|
"ascend",
|
||||||
or attention_backend_str == "ascend"
|
"trtllm_mha",
|
||||||
or attention_backend_str == "trtllm_mha"
|
"aiter",
|
||||||
or attention_backend_str == "aiter"
|
]
|
||||||
or global_server_args_dict["enable_two_batch_overlap"]
|
or global_server_args_dict["enable_two_batch_overlap"]
|
||||||
):
|
):
|
||||||
seq_lens_cpu = (
|
seq_lens_cpu = (
|
||||||
|
|||||||
@@ -729,10 +729,12 @@ class CudaGraphRunner:
|
|||||||
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
|
||||||
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
self.positions[:raw_num_token].copy_(forward_batch.positions)
|
||||||
|
|
||||||
|
seq_lens_cpu = None
|
||||||
if forward_batch.seq_lens_cpu is not None:
|
if forward_batch.seq_lens_cpu is not None:
|
||||||
if bs != raw_bs:
|
if bs != raw_bs:
|
||||||
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
self.seq_lens_cpu.fill_(self.seq_len_fill_value)
|
||||||
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
|
||||||
|
seq_lens_cpu = self.seq_lens_cpu[:bs]
|
||||||
|
|
||||||
if pp_proxy_tensors:
|
if pp_proxy_tensors:
|
||||||
for key in self.pp_proxy_tensors.keys():
|
for key in self.pp_proxy_tensors.keys():
|
||||||
@@ -766,7 +768,7 @@ class CudaGraphRunner:
|
|||||||
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
self.encoder_lens[:bs] if self.is_encoder_decoder else None,
|
||||||
self.capture_forward_mode,
|
self.capture_forward_mode,
|
||||||
forward_batch.spec_info,
|
forward_batch.spec_info,
|
||||||
seq_lens_cpu=self.seq_lens_cpu[:bs],
|
seq_lens_cpu=seq_lens_cpu,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store fields
|
# Store fields
|
||||||
|
|||||||
Reference in New Issue
Block a user