Remove useless variables in infer_batch.py (#651)
This commit is contained in:
@@ -270,6 +270,7 @@ class Batch:
|
|||||||
prefix_lens: torch.Tensor = None
|
prefix_lens: torch.Tensor = None
|
||||||
position_ids_offsets: torch.Tensor = None
|
position_ids_offsets: torch.Tensor = None
|
||||||
out_cache_loc: torch.Tensor = None
|
out_cache_loc: torch.Tensor = None
|
||||||
|
extend_num_tokens: int = None
|
||||||
|
|
||||||
# For processing logprobs
|
# For processing logprobs
|
||||||
return_logprob: bool = False
|
return_logprob: bool = False
|
||||||
@@ -280,10 +281,6 @@ class Batch:
|
|||||||
image_sizes: List[List[int]] = None
|
image_sizes: List[List[int]] = None
|
||||||
image_offsets: List[int] = None
|
image_offsets: List[int] = None
|
||||||
|
|
||||||
# Other arguments for control
|
|
||||||
output_ids: torch.Tensor = None
|
|
||||||
extend_num_tokens: int = None
|
|
||||||
|
|
||||||
# Batched sampling params
|
# Batched sampling params
|
||||||
temperatures: torch.Tensor = None
|
temperatures: torch.Tensor = None
|
||||||
top_ps: torch.Tensor = None
|
top_ps: torch.Tensor = None
|
||||||
@@ -820,6 +817,7 @@ def init_flashinfer_args(
|
|||||||
prefix_lens,
|
prefix_lens,
|
||||||
flashinfer_decode_wrapper,
|
flashinfer_decode_wrapper,
|
||||||
):
|
):
|
||||||
|
"""Init auxiliary variables for FlashInfer attention backend."""
|
||||||
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
|
||||||
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
|
||||||
head_dim = model_runner.model_config.head_dim
|
head_dim = model_runner.model_config.head_dim
|
||||||
@@ -885,6 +883,7 @@ def init_flashinfer_args(
|
|||||||
|
|
||||||
|
|
||||||
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
def init_triton_args(forward_mode, seq_lens, prefix_lens):
|
||||||
|
"""Init auxiliary variables for triton attention backend."""
|
||||||
batch_size = len(seq_lens)
|
batch_size = len(seq_lens)
|
||||||
max_seq_len = int(torch.max(seq_lens))
|
max_seq_len = int(torch.max(seq_lens))
|
||||||
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
|
||||||
|
|||||||
Reference in New Issue
Block a user