diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index dd2b59728..d54b14ef2 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -148,9 +148,6 @@ class InputMetadata:
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(x == 0 for x in prefix_lens_cpu)
 
-    def init_total_num_tokens(self, batch: ScheduleBatch):
-        self.total_num_tokens = sum(len(req.fill_ids) for req in batch.reqs)
-
     @classmethod
     def from_schedule_batch(
         cls,
@@ -174,7 +171,11 @@ class InputMetadata:
 
         ret.compute_extend_infos(batch)
 
-        ret.init_total_num_tokens(batch)
+        if (
+            forward_mode != ForwardMode.DECODE
+            or model_runner.server_args.disable_flashinfer
+        ):
+            ret.total_num_tokens = int(torch.sum(ret.seq_lens))
 
         if forward_mode != ForwardMode.DECODE:
             ret.init_multimuldal_info(batch)
@@ -203,7 +204,7 @@ class InputMetadata:
 
     def init_triton_args(self, batch: ScheduleBatch, prefix_lens):
         """Init auxiliary variables for triton attention backend."""
-        self.triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
+        self.triton_max_seq_len = int(torch.max(self.seq_lens))
         self.triton_prefix_lens = prefix_lens
         self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32)
         self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0)
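
For reference, a minimal standalone sketch of the tensor-based aggregates this patch switches to. The `seq_lens` tensor and the literal values below are made-up stand-ins for illustration, not the actual sglang batch objects:

```python
# Hypothetical example (not the sglang API): the patch replaces per-request
# Python iteration over req.fill_ids with reductions over the seq_lens tensor.
import torch

# Stand-in for InputMetadata.seq_lens: one sequence length per request.
seq_lens = torch.tensor([5, 12, 7], dtype=torch.int32)

# New total_num_tokens: sum of all sequence lengths in the batch,
# cast to a plain Python int for downstream consumers.
total_num_tokens = int(torch.sum(seq_lens))    # 24

# New triton_max_seq_len: longest sequence in the batch.
triton_max_seq_len = int(torch.max(seq_lens))  # 12

# Old behavior, for comparison, looped over request objects in Python:
#   total_num_tokens   = sum(len(req.fill_ids) for req in batch.reqs)
#   triton_max_seq_len = max(len(r.fill_ids) for r in batch.reqs)
```

With the added `if` guard in `from_schedule_batch`, `total_num_tokens` is only computed when it is actually consumed (extend mode, or decode with flashinfer disabled), instead of unconditionally on every batch.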