[CI] Fix broken ci (#2530)
vLLM PR https://github.com/vllm-project/vllm/pull/22711 changed the
encoder cache entry logic; this PR adopts the same change in
vllm-ascend to make CI happy.
Co-Authored-By: zhoux77899 <zhouxiang100@huawei.com>
- vLLM version: v0.10.1.1
- vLLM main: 0ff902f3b4
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
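
For context, the change boils down to how encoder outputs are keyed. On vLLM v0.10.1.1 the cache is a nested dict keyed by request id and input index; on current vLLM main it is a flat dict keyed by the multimodal item's mm_hash, so entries are independent of the owning request and are freed via free_encoder_mm_hashes. A minimal sketch of the two layouts follows; the ids and tensors are placeholders, not values from this PR:

    from typing import Dict, Union

    import torch

    # Old layout (vLLM v0.10.1.1): outputs are owned by a request and
    # indexed by the position of the multimodal input inside that request.
    old_cache: Dict[str, Dict[int, torch.Tensor]] = {
        "req-0": {0: torch.zeros(4, 8)},  # placeholder tensor
    }
    old_entry = old_cache["req-0"][0]

    # New layout (vLLM main after #22711): outputs are keyed by the item's
    # mm_hash, independent of which request produced them.
    new_cache: Dict[str, torch.Tensor] = {
        "mm-hash-abc": torch.zeros(4, 8),  # placeholder tensor
    }
    new_entry = new_cache.get("mm-hash-abc")

    # Until support for v0.10.1.1 is dropped, the runner declares the
    # attribute with both shapes (as in the first hunk below):
    encoder_cache: Union[Dict[str, Dict[int, torch.Tensor]],
                         Dict[str, torch.Tensor]] = {}
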
@@ -193,7 +193,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Lazy initialization, these will be set after __init__
         self.kv_caches: List[torch.Tensor] = []
-        self.encoder_cache: Dict[str, Dict[int, torch.Tensor]] = {}
+        # TODO: remove Dict[str, Dict[int, torch.Tensor]] type after 0.10.1.1
+        self.encoder_cache: Union[Dict[str, Dict[int, torch.Tensor]],
+                                  Dict[str, torch.Tensor]] = {}
         self.attn_mask = None
         self.attn_state = None
         self.requests: Dict[str, CachedRequestState] = {}
@@ -381,7 +383,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
-            self.encoder_cache.pop(req_id, None)
+            if vllm_version_is("0.10.1.1"):
+                self.encoder_cache.pop(req_id, None)
         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
         # scheduled_req_ids overlap. This happens when a request is aborted and
@@ -390,15 +393,17 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         # and handling the second as a new request.
         for req_id in scheduler_output.finished_req_ids:
            self.input_batch.remove_request(req_id)
 
-        # Free the cached encoder outputs.
-        for req_id, input_id in scheduler_output.free_encoder_input_ids:
-            encoder_outputs = self.encoder_cache.get(req_id)
-            if encoder_outputs is not None:
-                encoder_outputs.pop(input_id, None)
-                if not encoder_outputs:
-                    self.encoder_cache.pop(req_id, None)
+        if vllm_version_is("0.10.1.1"):
+            # Free the cached encoder outputs.
+            for req_id, input_id in scheduler_output.free_encoder_input_ids:
+                encoder_outputs = self.encoder_cache.get(req_id)
+                if encoder_outputs is not None:
+                    encoder_outputs.pop(input_id, None)
+                    if not encoder_outputs:
+                        self.encoder_cache.pop(req_id, None)
+        else:
+            for mm_hash in scheduler_output.free_encoder_mm_hashes:
+                self.encoder_cache.pop(mm_hash, None)
 
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -447,6 +452,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             num_computed_tokens=new_req_data.num_computed_tokens,
             output_token_ids=[],
             lora_request=new_req_data.lora_request,
+            **({
+                "mm_hashes": new_req_data.mm_hashes
+            } if not vllm_version_is("0.10.1.1") else {
+                "mm_hashes": None
+            }),
         )
 
         # Only relevant for models using M-RoPE (e.g, Qwen2-VL)
@@ -882,15 +892,25 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         # Batch the multi-modal inputs.
         mm_kwargs = list[MultiModalKwargsItem]()
-        req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        if vllm_version_is("0.10.1.1"):
+            req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
+        else:
+            mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
         for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
             req_state = self.requests[req_id]
 
-            for mm_input_id in encoder_input_ids:
-                mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
-                req_ids_pos.append(
-                    (req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
+            if vllm_version_is("0.10.1.1"):
+                for mm_input_id in encoder_input_ids:
+                    mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
+                    req_ids_pos.append((req_id, mm_input_id,
+                                        req_state.mm_positions[mm_input_id]))
+            else:
+                for mm_input_id in encoder_input_ids:
+                    # TODO remove this assert after 0.10.1.1
+                    assert req_state.mm_hashes is not None
+                    mm_hash = req_state.mm_hashes[mm_input_id]
+                    mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
+                    mm_hashes_pos.append(
+                        (mm_hash, req_state.mm_positions[mm_input_id]))
 
         # Batch mm inputs as much as we can: if a request in the batch has
         # multiple modalities or a different modality than the previous one,
         # we process it separately to preserve item order.
@@ -921,19 +941,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         for output in curr_group_outputs:
             encoder_outputs.append(output)
 
-        # Cache the encoder outputs.
-        for (req_id, input_id, pos_info), output in zip(
-                req_ids_pos,
-                encoder_outputs,
-        ):
-            if req_id not in self.encoder_cache:
-                self.encoder_cache[req_id] = {}
-
-            self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed,
-            )
+        if vllm_version_is("0.10.1.1"):
+            # Cache the encoder outputs.
+            for (req_id, input_id, pos_info), output in zip(
+                    req_ids_pos,
+                    encoder_outputs,
+            ):
+                if req_id not in self.encoder_cache:
+                    self.encoder_cache[req_id] = {}
+
+                self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
+                    output,
+                    is_embed=pos_info.is_embed,
+                )
+        else:
+            for (mm_hash, pos_info), output in zip(mm_hashes_pos,
+                                                   encoder_outputs):
+                self.encoder_cache[mm_hash] = scatter_mm_placeholders(
+                    output,
+                    is_embed=pos_info.is_embed,
+                )
 
     def _gather_mm_embeddings(
         self,
@@ -946,6 +973,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             req_state = self.requests[req_id]
             num_computed_tokens = req_state.num_computed_tokens
             mm_positions = req_state.mm_positions
+            if not vllm_version_is("0.10.1.1"):
+                mm_hashes = req_state.mm_hashes
             for i, pos_info in enumerate(mm_positions):
                 start_pos = pos_info.offset
                 num_encoder_tokens = pos_info.length
@@ -963,13 +992,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     continue
 
                 start_idx = max(num_computed_tokens - start_pos, 0)
-                end_idx = min(
-                    num_computed_tokens - start_pos + num_scheduled_tokens,
-                    num_encoder_tokens)
-                assert start_idx < end_idx
-                assert req_id in self.encoder_cache
-                assert i in self.encoder_cache[req_id]
-                encoder_output = self.encoder_cache[req_id][i]
+                if vllm_version_is("0.10.1.1"):
+                    end_idx = min(
+                        num_computed_tokens - start_pos + num_scheduled_tokens,
+                        num_encoder_tokens)
+                    assert start_idx < end_idx
+                    assert req_id in self.encoder_cache
+                    assert i in self.encoder_cache[req_id]
+                    encoder_output = self.encoder_cache[req_id][i]
+                else:
+                    end_idx = min(
+                        num_computed_tokens - start_pos + num_scheduled_tokens,
+                        num_encoder_tokens,
+                    )
+                    assert start_idx < end_idx
+                    # TODO remove this assert after 0.10.1.1
+                    assert mm_hashes is not None
+                    mm_hash = mm_hashes[i]
+                    encoder_output = self.encoder_cache.get(mm_hash, None)
+                    assert encoder_output is not None,\
+                        f"Encoder cache miss for {mm_hash}."
 
                 if (is_embed := pos_info.is_embed) is not None:
                     is_embed = is_embed[start_idx:end_idx]
@@ -47,6 +47,8 @@ class CachedRequestState:
     prompt_token_ids: list[int]
     mm_kwargs: list[MultiModalKwargsItem]
    mm_positions: list[PlaceholderRange]
+    # TODO: remove Optional after 0.10.1.1
+    mm_hashes: Optional[list[str]]
     sampling_params: Optional[SamplingParams]
     pooling_params: Optional[PoolingParams]
     generator: Optional[torch.Generator]
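
The **({...}) construct in the _update_states hunk above is just a version-gated keyword argument: CachedRequestState now always carries an mm_hashes field, but it is only populated when running against vLLM main and stays None on v0.10.1.1. A small sketch of the pattern; the field name matches the diff, while the class and helper here are simplified stand-ins, not the real CachedRequestState:

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class _RequestState:  # stand-in for CachedRequestState, illustration only
        req_id: str
        # TODO in the real code: remove Optional after 0.10.1.1
        mm_hashes: Optional[list[str]] = None


    def make_state(req_id: str, mm_hashes: list[str],
                   on_legacy_vllm: bool) -> _RequestState:
        # Mirrors the gated kwargs in the diff: on v0.10.1.1 the scheduler's
        # request data has no mm_hashes, so the field is explicitly set to None.
        return _RequestState(
            req_id=req_id,
            **({"mm_hashes": mm_hashes} if not on_legacy_vllm else {"mm_hashes": None}),
        )


    print(make_state("req-0", ["mm-hash-abc"], on_legacy_vllm=False).mm_hashes)
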