Simplify flashinfer utilities (#1704)

This commit is contained in:
Lianmin Zheng
2024-10-17 22:54:14 -07:00
committed by GitHub
parent 9e0dac1ad7
commit 6d0fa73ece
8 changed files with 391 additions and 337 deletions

View File

@@ -134,9 +134,7 @@ class ForwardBatch:
)
# Init position information
- if ret.forward_mode.is_decode():
- ret.positions = (ret.seq_lens - 1).to(torch.int64)
- else:
+ if not ret.forward_mode.is_decode():
ret.positions = torch.tensor(
np.concatenate(
[
@@ -164,7 +162,6 @@ class ForwardBatch:
ret.req_to_token_pool = model_runner.req_to_token_pool
ret.token_to_kv_pool = model_runner.token_to_kv_pool
ret.attn_backend = model_runner.attn_backend
- model_runner.attn_backend.init_forward_metadata(ret)
# Init lora information
if model_runner.server_args.lora_paths is not None:

View File

@@ -551,11 +551,14 @@ class ModelRunner:
):
return self.cuda_graph_runner.replay(forward_batch)
+ forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+ self.attn_backend.init_forward_metadata(forward_batch)
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
def forward_extend(self, forward_batch: ForwardBatch):
+ self.attn_backend.init_forward_metadata(forward_batch)
if self.is_generation:
return self.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch