diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 5f8cc0ed4..4e73dd9ae 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -646,12 +646,17 @@ class ForwardBatch: device=model_runner.device, ) - bs = self.batch_size if len(global_num_tokens) > 1: num_tokens = global_num_tokens[get_attention_dp_rank()] else: num_tokens = global_num_tokens[0] + if self.forward_mode.is_decode(): + setattr(self, "raw_bs", self.batch_size) + self.batch_size = num_tokens + + bs = self.batch_size + # padding self.input_ids = self._pad_tensor_to_size(self.input_ids, num_tokens) self.req_pool_indices = self._pad_tensor_to_size(self.req_pool_indices, bs) @@ -659,6 +664,9 @@ class ForwardBatch: seq_len_fill_value = ( model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) + self.seq_lens_sum = self.seq_lens_sum + seq_len_fill_value * ( + bs - self.seq_lens.shape[0] + ) self.seq_lens = self._pad_tensor_to_size( self.seq_lens, bs, value=seq_len_fill_value ) @@ -702,7 +710,7 @@ class ForwardBatch: def post_forward_mlp_sync_batch(self, logits_output: LogitsProcessorOutput): - bs = self.batch_size + bs = getattr(self, "raw_bs", self.batch_size) if self.spec_info is not None: if self.forward_mode.is_decode(): # draft