diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 6a03ba97c..191bf388f 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -270,6 +270,7 @@ class Batch: prefix_lens: torch.Tensor = None position_ids_offsets: torch.Tensor = None out_cache_loc: torch.Tensor = None + extend_num_tokens: int = None # For processing logprobs return_logprob: bool = False @@ -280,10 +281,6 @@ class Batch: image_sizes: List[List[int]] = None image_offsets: List[int] = None - # Other arguments for control - output_ids: torch.Tensor = None - extend_num_tokens: int = None - # Batched sampling params temperatures: torch.Tensor = None top_ps: torch.Tensor = None @@ -820,6 +817,7 @@ def init_flashinfer_args( prefix_lens, flashinfer_decode_wrapper, ): + """Init auxiliary variables for FlashInfer attention backend.""" num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size) head_dim = model_runner.model_config.head_dim @@ -885,6 +883,7 @@ def init_flashinfer_args( def init_triton_args(forward_mode, seq_lens, prefix_lens): + """Init auxiliary variables for triton attention backend.""" batch_size = len(seq_lens) max_seq_len = int(torch.max(seq_lens)) start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")