[Bugfix] Add func swap_states to fix MLA attention (#1580)
### What this PR does / why we need it? mla attention still using the gpu_input_batch's attr:`swap_states`, which will lead to an error `AttributeError: 'InputBatch' object has no attribute 'swap_states'` This PR fixed the mla input patch error ### How was this patch tested? will be tested by #1136 --------- Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
@@ -423,6 +424,64 @@ class InputBatch:
|
||||
self.pooling_params.pop(req_id, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
old_id_i1 = self._req_ids[i1]
|
||||
old_id_i2 = self._req_ids[i2]
|
||||
self._req_ids[i1], self._req_ids[i2] =\
|
||||
self._req_ids[i2], self._req_ids[i1] # noqa
|
||||
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
|
||||
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
|
||||
assert old_id_i1 is not None and old_id_i2 is not None
|
||||
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
|
||||
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
|
||||
self.num_tokens[i1], self.num_tokens[i2] =\
|
||||
self.num_tokens[i2], self.num_tokens[i1]
|
||||
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
|
||||
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
|
||||
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
|
||||
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
|
||||
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
|
||||
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
|
||||
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
|
||||
self.temperature_cpu[i2], self.temperature_cpu[i1]
|
||||
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
|
||||
self.top_p_cpu[i2], self.top_p_cpu[i1]
|
||||
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
|
||||
self.top_k_cpu[i2], self.top_k_cpu[i1]
|
||||
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
|
||||
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
|
||||
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
|
||||
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
|
||||
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
|
||||
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
|
||||
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
|
||||
self.min_p_cpu[i2], self.min_p_cpu[i1]
|
||||
|
||||
# NOTE: the following is unsafe
|
||||
# self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
|
||||
# self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
|
||||
# instead, we need to temporiarily copy the data for one of the indices
|
||||
# TODO(lucas): optimize this by only copying valid indices
|
||||
tmp = self.token_ids_cpu[i1, ...].copy()
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
self.logit_bias[i1], self.logit_bias[i2] =\
|
||||
self.logit_bias[i2], self.logit_bias[i1]
|
||||
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2] =\
|
||||
self.allowed_token_ids_mask_cpu_tensor[i2], \
|
||||
self.allowed_token_ids_mask_cpu_tensor[i1]
|
||||
self.block_table.swap_row(i1, i2)
|
||||
|
||||
def condense(self, empty_req_indices: list[int]) -> None:
|
||||
"""Move non-empty requests down into lower, empty indices.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user