Fix triton args init (#1034)
This commit is contained in:
@@ -148,9 +148,6 @@ class InputMetadata:
             self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
             self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
 
-    def init_total_num_tokens(self, batch: ScheduleBatch):
-        self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
-
     @classmethod
     def from_schedule_batch(
         cls,
@@ -174,7 +171,11 @@ class InputMetadata:
 
         ret.compute_extend_infos(batch)
 
-        ret.init_total_num_tokens(batch)
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
 
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
@@ -203,7 +204,7 @@ class InputMetadata:
 
     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
         """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
         self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
Reference in New Issue
Block a user