Fix triton args init (#1034)
This commit is contained in:
@@ -148,9 +148,6 @@ class InputMetadata:
|
||||
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
|
||||
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
|
||||
|
||||
def init_total_num_tokens(self, batch: ScheduleBatch):
|
||||
self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
|
||||
|
||||
@classmethod
|
||||
def from_schedule_batch(
|
||||
cls,
|
||||
@@ -174,7 +171,11 @@ class InputMetadata:
|
||||
|
||||
ret.compute_extend_infos(batch)
|
||||
|
||||
ret.init_total_num_tokens(batch)
|
||||
if (
|
||||
forward_mode != ForwardMode.DECODE
|
||||
or model_runner.server_args.disable_flashinfer
|
||||
):
|
||||
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
|
||||
|
||||
if forward_mode != ForwardMode.DECODE:
|
||||
ret.init_multimuldal_info(batch)
|
||||
@@ -203,7 +204,7 @@ class InputMetadata:
|
||||
|
||||
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
|
||||
self.triton_max_seq_len = int(torch.max(self.seq_lens))
|
||||
self.triton_prefix_lens = prefix_lens
|
||||
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
|
||||
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
|
||||
|
||||
Reference in New Issue
Block a user