Remove useless variables in infer_batch.py (#651)
This commit is contained in:
@@ -270,6 +270,7 @@ class Batch:
|
||||
prefix_lens: torch.Tensor = None
|
||||
position_ids_offsets: torch.Tensor = None
|
||||
out_cache_loc: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
|
||||
# For processing logprobs
|
||||
return_logprob: bool = False
|
||||
@@ -280,10 +281,6 @@ class Batch:
|
||||
image_sizes: List[List[int]] = None
|
||||
image_offsets: List[int] = None
|
||||
|
||||
# Other arguments for control
|
||||
output_ids: torch.Tensor = None
|
||||
extend_num_tokens: int = None
|
||||
|
||||
# Batched sampling params
|
||||
temperatures: torch.Tensor = None
|
||||
top_ps: torch.Tensor = None
|
||||
@@ -820,6 +817,7 @@ def init_flashinfer_args(
|
||||
prefix_lens,
|
||||
flashinfer_decode_wrapper,
|
||||
):
|
||||
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||
head_dim = model_runner.model_config.head_dim
|
||||
@@ -885,6 +883,7 @@ def init_flashinfer_args(
|
||||
|
||||
|
||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||
"""Init auxiliary variables for triton attention backend."""
|
||||
batch_size = len(seq_lens)
|
||||
max_seq_len = int(torch.max(seq_lens))
|
||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||
|
||||
Reference in New Issue
Block a user