Fix triton args init (#1034)

This commit is contained in:
Liangsheng Yin
2024-08-11 12:11:26 -07:00
committed by GitHub
parent 4080e82244
commit 7b6a5332ca

View File

@@ -148,9 +148,6 @@ class InputMetadata:
self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
def init_total_num_tokens(self, batch: ScheduleBatch):
self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
@classmethod
def from_schedule_batch(
cls,
@@ -174,7 +171,11 @@ class InputMetadata:
ret.compute_extend_infos(batch)
ret.init_total_num_tokens(batch)
if (
forward_mode != ForwardMode.DECODE
or model_runner.server_args.disable_flashinfer
):
ret.total_num_tokens = int(torch.sum(ret.seq_lens))
if forward_mode != ForwardMode.DECODE:
ret.init_multimuldal_info(batch)
@@ -203,7 +204,7 @@ class InputMetadata:
def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
"""Init auxiliary variables for triton attention backend."""
self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
self.triton_max_seq_len = int(torch.max(self.seq_lens))
self.triton_prefix_lens = prefix_lens
self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)