Simplify flashinfer utilities (#1704)
This commit is contained in:
@@ -134,9 +134,7 @@ class ForwardBatch:
|
||||
)
|
||||
|
||||
# Init position information
|
||||
if ret.forward_mode.is_decode():
|
||||
ret.positions = (ret.seq_lens - 1).to(torch.int64)
|
||||
else:
|
||||
if not ret.forward_mode.is_decode():
|
||||
ret.positions = torch.tensor(
|
||||
np.concatenate(
|
||||
[
|
||||
@@ -164,7 +162,6 @@ class ForwardBatch:
|
||||
ret.req_to_token_pool = model_runner.req_to_token_pool
|
||||
ret.token_to_kv_pool = model_runner.token_to_kv_pool
|
||||
ret.attn_backend = model_runner.attn_backend
|
||||
model_runner.attn_backend.init_forward_metadata(ret)
|
||||
|
||||
# Init lora information
|
||||
if model_runner.server_args.lora_paths is not None:
|
||||
|
||||
@@ -551,11 +551,14 @@ class ModelRunner:
|
||||
):
|
||||
return self.cuda_graph_runner.replay(forward_batch)
|
||||
|
||||
forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
)
|
||||
|
||||
def forward_extend(self, forward_batch: ForwardBatch):
|
||||
self.attn_backend.init_forward_metadata(forward_batch)
|
||||
if self.is_generation:
|
||||
return self.model.forward(
|
||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||
|
||||
Reference in New Issue
Block a user