Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741)

This commit is contained in:
Lianmin Zheng
2024-10-21 01:43:16 -07:00
committed by GitHub
parent cf470fea32
commit 09603c6dc9
8 changed files with 98 additions and 43 deletions

View File

@@ -188,6 +188,7 @@ class CudaGraphRunner:
req_pool_indices = self.req_pool_indices[:bs]
seq_lens = self.seq_lens[:bs]
out_cache_loc = self.out_cache_loc[:bs]
seq_lens_sum = seq_lens.sum().item()
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -206,6 +207,7 @@ class CudaGraphRunner:
token_to_kv_pool=self.model_runner.token_to_kv_pool,
attn_backend=self.model_runner.attn_backend,
out_cache_loc=out_cache_loc,
seq_lens_sum=seq_lens_sum,
return_logprob=False,
top_logprobs_nums=[0] * bs,
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
@@ -252,7 +254,10 @@ class CudaGraphRunner:
# Attention backend
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
bs, self.req_pool_indices, self.seq_lens
bs,
self.req_pool_indices,
self.seq_lens,
forward_batch.seq_lens_sum,
)
# Replay

View File

@@ -87,6 +87,9 @@ class ForwardBatch:
# The indices of output tokens in the token_to_kv_pool
out_cache_loc: torch.Tensor
# The sum of all sequence lengths
seq_lens_sum: int
# For logprob
return_logprob: bool = False
top_logprobs_nums: Optional[List[int]] = None
@@ -95,6 +98,7 @@ class ForwardBatch:
positions: torch.Tensor = None
# For extend
extend_num_tokens: Optional[int] = None
extend_seq_lens: Optional[torch.Tensor] = None
extend_prefix_lens: Optional[torch.Tensor] = None
extend_start_loc: Optional[torch.Tensor] = None
@@ -175,21 +179,6 @@ class ForwardBatch:
)
self.mrope_positions = self.mrope_positions.to(torch.int64)
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
    """Fill in ``self.positions`` according to the current forward mode.

    In decode mode every request gets a single position equal to its
    sequence length minus one, cast to ``torch.int64``.

    In any other mode, each request contributes the consecutive positions
    ``prefix_len .. prefix_len + extend_len - 1`` (taken pairwise from
    ``batch.extend_prefix_lens`` and ``batch.extend_seq_lens``), and the
    per-request ranges are concatenated in request order on
    ``model_runner.device``.
    """
    target_device = model_runner.device

    if self.forward_mode.is_decode():
        self.positions = (self.seq_lens - 1).to(torch.int64)
        return

    # One arange per request, starting right after its cached prefix.
    per_request_ranges = []
    for start, count in zip(batch.extend_prefix_lens, batch.extend_seq_lens):
        per_request_ranges.append(
            torch.arange(start, start + count, device=target_device)
        )
    self.positions = torch.concat(per_request_ranges, axis=0)
@classmethod
def init_new(
cls,
@@ -205,6 +194,7 @@ class ForwardBatch:
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
out_cache_loc=batch.out_cache_loc,
seq_lens_sum=batch.seq_lens_sum,
return_logprob=batch.return_logprob,
top_logprobs_nums=batch.top_logprobs_nums,
lora_paths=batch.lora_paths,
@@ -213,7 +203,17 @@ class ForwardBatch:
# Init position information
if not ret.forward_mode.is_decode():
ret.positions = torch.concat(
[
torch.arange(prefix_len, prefix_len + extend_len, device=device)
for prefix_len, extend_len in zip(
batch.extend_prefix_lens, batch.extend_seq_lens
)
],
axis=0,
)
ret.image_inputs = batch.image_inputs
ret.extend_num_tokens = batch.extend_num_tokens
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
).to(device, non_blocking=True)
@@ -225,12 +225,8 @@ class ForwardBatch:
ret.extend_seq_lens_cpu = batch.extend_seq_lens
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
# Init position information
is_mrope = model_runner.model_is_mrope
if is_mrope:
if model_runner.model_is_mrope:
ret.compute_mrope_positions(model_runner, batch)
else:
ret.compute_positions(model_runner, batch)
# Init attention information
ret.req_to_token_pool = model_runner.req_to_token_pool