Maintain seq_lens_sum to make more FlashInfer operations non-blocking (#1741)
This commit is contained in:
@@ -188,6 +188,7 @@ class CudaGraphRunner:
|
||||
req_pool_indices = self.req_pool_indices[:bs]
|
||||
seq_lens = self.seq_lens[:bs]
|
||||
out_cache_loc = self.out_cache_loc[:bs]
|
||||
seq_lens_sum = seq_lens.sum().item()
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
@@ -206,6 +207,7 @@ class CudaGraphRunner:
|
||||
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
||||
attn_backend=self.model_runner.attn_backend,
|
||||
out_cache_loc=out_cache_loc,
|
||||
seq_lens_sum=seq_lens_sum,
|
||||
return_logprob=False,
|
||||
top_logprobs_nums=[0] * bs,
|
||||
positions=torch.clamp((seq_lens - 1), min=0).to(torch.int64),
|
||||
@@ -252,7 +254,10 @@ class CudaGraphRunner:
|
||||
|
||||
# Attention backend
|
||||
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs, self.req_pool_indices, self.seq_lens
|
||||
bs,
|
||||
self.req_pool_indices,
|
||||
self.seq_lens,
|
||||
forward_batch.seq_lens_sum,
|
||||
)
|
||||
|
||||
# Replay
|
||||
|
||||
@@ -87,6 +87,9 @@ class ForwardBatch:
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The sum of all sequence lengths
|
||||
seq_lens_sum: int
|
||||
|
||||
# For logprob
|
||||
return_logprob: bool = False
|
||||
top_logprobs_nums: Optional[List[int]] = None
|
||||
@@ -95,6 +98,7 @@ class ForwardBatch:
|
||||
positions: torch.Tensor = None
|
||||
|
||||
# For extend
|
||||
extend_num_tokens: Optional[int] = None
|
||||
extend_seq_lens: Optional[torch.Tensor] = None
|
||||
extend_prefix_lens: Optional[torch.Tensor] = None
|
||||
extend_start_loc: Optional[torch.Tensor] = None
|
||||
@@ -175,21 +179,6 @@ class ForwardBatch:
|
||||
)
|
||||
self.mrope_positions = self.mrope_positions.to(torch.int64)
|
||||
|
||||
def compute_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
|
||||
device = model_runner.device
|
||||
if self.forward_mode.is_decode():
|
||||
self.positions = (self.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
self.positions = torch.concat(
|
||||
[
|
||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init_new(
|
||||
cls,
|
||||
@@ -205,6 +194,7 @@ class ForwardBatch:
|
||||
req_pool_indices=batch.req_pool_indices,
|
||||
seq_lens=batch.seq_lens,
|
||||
out_cache_loc=batch.out_cache_loc,
|
||||
seq_lens_sum=batch.seq_lens_sum,
|
||||
return_logprob=batch.return_logprob,
|
||||
top_logprobs_nums=batch.top_logprobs_nums,
|
||||
lora_paths=batch.lora_paths,
|
||||
@@ -213,7 +203,17 @@ class ForwardBatch:
|
||||
|
||||
# Init position information
|
||||
if not ret.forward_mode.is_decode():
|
||||
ret.positions = torch.concat(
|
||||
[
|
||||
torch.arange(prefix_len, prefix_len + extend_len, device=device)
|
||||
for prefix_len, extend_len in zip(
|
||||
batch.extend_prefix_lens, batch.extend_seq_lens
|
||||
)
|
||||
],
|
||||
axis=0,
|
||||
)
|
||||
ret.image_inputs = batch.image_inputs
|
||||
ret.extend_num_tokens = batch.extend_num_tokens
|
||||
ret.extend_seq_lens = torch.tensor(
|
||||
batch.extend_seq_lens, dtype=torch.int32
|
||||
).to(device, non_blocking=True)
|
||||
@@ -225,12 +225,8 @@ class ForwardBatch:
|
||||
ret.extend_seq_lens_cpu = batch.extend_seq_lens
|
||||
ret.extend_logprob_start_lens_cpu = batch.extend_logprob_start_lens
|
||||
|
||||
# Init position information
|
||||
is_mrope = model_runner.model_is_mrope
|
||||
if is_mrope:
|
||||
if model_runner.model_is_mrope:
|
||||
ret.compute_mrope_positions(model_runner, batch)
|
||||
else:
|
||||
ret.compute_positions(model_runner, batch)
|
||||
|
||||
# Init attention information
|
||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||
|
||||
Reference in New Issue
Block a user