From f22077daa6a32e1d5c5cfe0e84da2cea1ab8cafb Mon Sep 17 00:00:00 2001
From: wangxiyuan
Date: Wed, 27 Aug 2025 09:22:01 +0800
Subject: [PATCH] [Embedding] Recover embedding function (#2483)

Fix the broken embedding function. It was broken by
http://github.com/vllm-project/vllm/pull/23162

- vLLM version: v0.10.1.1
- vLLM main: https://github.com/vllm-project/vllm/commit/efc88cf64a399f5459cd6256223e99672c13614d

Signed-off-by: wangxiyuan
---
 tests/e2e/singlecard/test_embedding.py |   3 -
 vllm_ascend/worker/model_runner_v1.py  | 161 +++++++++++++++++--------
 vllm_ascend/worker/npu_input_batch.py  |  23 ++--
 3 files changed, 129 insertions(+), 58 deletions(-)

diff --git a/tests/e2e/singlecard/test_embedding.py b/tests/e2e/singlecard/test_embedding.py
index 1fc594e..2868dc2 100644
--- a/tests/e2e/singlecard/test_embedding.py
+++ b/tests/e2e/singlecard/test_embedding.py
@@ -19,7 +19,6 @@
 from collections.abc import Sequence
 from typing import Optional
 
-import pytest
 from modelscope import snapshot_download  # type: ignore[import-untyped]
 
 from tests.e2e.conftest import HfRunner
@@ -50,8 +49,6 @@ def test_dummy():
     assert True
 
 
-@pytest.mark.skip(
-    reason="TODO: revert me when pooler is adapted with the latest vllm main")
 def test_embed_models_correctness(hf_runner, vllm_runner):
     queries = ['What is the capital of China?', 'Explain gravity']
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 4dc186f..8fe840a 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1545,7 +1545,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 hidden_states, attn_metadata)
         return draft_token_ids
 
-    def _pool(
+    def _pool_v010(
         self,
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
@@ -1579,29 +1579,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             else:
                 pooler_output.append(None)
         extra_args = ({"kv_connector_output": kv_connector_output})
-        if vllm_version_is("0.10.1.1"):
-            modelrunner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=[],
-                spec_token_ids=None,
-                logprobs=None,
-                prompt_logprobs_dict={},
-                pooler_output=pooler_output,
-                **extra_args,
-            )
-        else:
-            modelrunner_output = ModelRunnerOutput(
-                req_ids=self.input_batch.req_ids,
-                req_id_to_index=self.input_batch.req_id_to_index,
-                sampled_token_ids=[],
-                logprobs=None,
-                prompt_logprobs_dict={},
-                pooler_output=pooler_output,
-                **extra_args,
-            )
+        modelrunner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=[],
+            spec_token_ids=None,
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=pooler_output,
+            **extra_args,
+        )
         return modelrunner_output
 
+    def _pool(
+        self,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: int,
+        num_scheduled_tokens_np: np.ndarray,
+        finished_sending: Optional[set[str]] = None,
+        finished_recving: Optional[set[str]] = None,
+        kv_connector_output: Optional["KVConnectorOutput"] = None,
+    ) -> ModelRunnerOutput:
+        assert self.input_batch.num_reqs ==\
+            len(self.input_batch.pooling_params), \
+        "Either all or none of the requests in" \
+        " a batch must be pooling request"
+
+        hidden_states = hidden_states[:num_scheduled_tokens]
+        pooling_metadata = self.input_batch.pooling_metadata
+        pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
+                                              device=hidden_states.device)
+        seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
+
+        # Pooling models D2H & synchronize occurs in pooler.py:build_output
+        raw_pooler_output = self.model.pooler(
+            hidden_states=hidden_states, pooling_metadata=pooling_metadata)
+
+        pooler_output: list[Optional[torch.Tensor]] = []
+        for raw_output, seq_len, prompt_len in zip(
+                raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
+
+            if seq_len == prompt_len:
+                pooler_output.append(raw_output.data)
+            else:
+                pooler_output.append(None)
+
+        return ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids,
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=[],
+            logprobs=None,
+            prompt_logprobs_dict={},
+            pooler_output=pooler_output,
+            kv_connector_output=kv_connector_output,
+        )
+
     @torch.inference_mode()
     def execute_model(
         self,
@@ -1684,11 +1716,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logits = None
         else:
             if self.input_batch.pooling_params:
-                return self._pool(
-                    hidden_states,
-                    scheduler_output.total_num_scheduled_tokens,
-                    num_scheduled_tokens_np, finished_sending,
-                    finished_recving, kv_connector_output)
+                if vllm_version_is("0.10.1.1"):
+                    return self._pool_v010(
+                        hidden_states,
+                        scheduler_output.total_num_scheduled_tokens,
+                        num_scheduled_tokens_np, finished_sending,
+                        finished_recving, kv_connector_output)
+                else:
+                    return self._pool(
+                        hidden_states,
+                        scheduler_output.total_num_scheduled_tokens,
+                        num_scheduled_tokens_np, finished_sending,
+                        finished_recving, kv_connector_output)
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
         if broadcast_pp_output:
@@ -2141,10 +2180,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             torch.split(hidden_states, num_scheduled_tokens_list))
         req_num_tokens = num_tokens // num_reqs
 
-        dummy_prompt_lens = torch.tensor(
-            [h.shape[0] for h in hidden_states_list],
-            device=self.device,
-        )
         dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
                                       dtype=torch.int32,
                                       device=self.device)
@@ -2153,25 +2188,55 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         dummy_pooling_params = PoolingParams(task=task)
         to_update = model.pooler.get_pooling_updates(task)
         to_update.apply(dummy_pooling_params)
+        if vllm_version_is("0.10.1.1"):
+            dummy_prompt_lens = torch.tensor(
+                [h.shape[0] for h in hidden_states_list],
+                device=self.device,
+            )
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=dummy_prompt_lens,
+                prompt_token_ids=dummy_token_ids,
+                pooling_params=[dummy_pooling_params] * num_reqs,
+            )
 
-        dummy_metadata = PoolingMetadata(
-            prompt_lens=dummy_prompt_lens,
-            prompt_token_ids=dummy_token_ids,
-            pooling_params=[dummy_pooling_params] * num_reqs,
-        )
+            try:
+                return model.pooler(hidden_states=hidden_states_list,
+                                    pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler "
+                        f"({task=}) with {num_reqs} dummy requests. Please try "
+                        "lowering `max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
+        else:
+            dummy_prompt_lens = torch.tensor(
+                num_scheduled_tokens_list,
+                device="cpu",
+            )
+            dummy_metadata = PoolingMetadata(
+                prompt_lens=dummy_prompt_lens,
+                prompt_token_ids=dummy_token_ids,
+                pooling_params=[dummy_pooling_params] * num_reqs,
+            )
 
-        try:
-            return model.pooler(hidden_states=hidden_states_list,
-                                pooling_metadata=dummy_metadata)
-        except RuntimeError as e:
-            if 'out of memory' in str(e):
-                raise RuntimeError(
-                    "NPU out of memory occurred when warming up pooler "
-                    f"({task=}) with {num_reqs} dummy requests. Please try "
-                    "lowering `max_num_seqs` or `gpu_memory_utilization` when "
-                    "initializing the engine.") from e
-            else:
-                raise e
+            dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
+                                                device=hidden_states.device)
+
+            try:
+                return model.pooler(hidden_states=hidden_states,
+                                    pooling_metadata=dummy_metadata)
+            except RuntimeError as e:
+                if 'out of memory' in str(e):
+                    raise RuntimeError(
+                        "NPU out of memory occurred when warming up pooler "
+                        f"({task=}) with {num_reqs} dummy requests. Please try "
+                        "lowering `max_num_seqs` or `gpu_memory_utilization` when "
+                        "initializing the engine.") from e
+                else:
+                    raise e
 
     @torch.inference_mode()
     def _dummy_pooler_run(
diff --git a/vllm_ascend/worker/npu_input_batch.py b/vllm_ascend/worker/npu_input_batch.py
index d4a6298..7e1243a 100644
--- a/vllm_ascend/worker/npu_input_batch.py
+++ b/vllm_ascend/worker/npu_input_batch.py
@@ -39,6 +39,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import MultiGroupBlockTable
 
+from vllm_ascend.utils import vllm_version_is
+
 
 @dataclass
 class CachedRequestState:
@@ -724,13 +726,20 @@ class InputBatch:
         pooling_params = [
             self.pooling_params[req_id] for req_id in self.req_ids
         ]
-
-        return PoolingMetadata(
-            prompt_lens=torch.from_numpy(
-                self.num_prompt_tokens[:self.num_reqs]).to(self.device),
-            prompt_token_ids=self.sampling_metadata.prompt_token_ids,
-            pooling_params=pooling_params,
-        )
+        if vllm_version_is("0.10.1.1"):
+            return PoolingMetadata(
+                prompt_lens=torch.from_numpy(
+                    self.num_prompt_tokens[:self.num_reqs]).to(self.device),
+                prompt_token_ids=self.sampling_metadata.prompt_token_ids,
+                pooling_params=pooling_params,
+            )
+        else:
+            return PoolingMetadata(
+                prompt_lens=torch.from_numpy(
+                    self.num_prompt_tokens[:self.num_reqs]),
+                prompt_token_ids=self.sampling_metadata.prompt_token_ids,
+                pooling_params=pooling_params,
+            )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
         max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
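
Usage sketch (illustrative only, not part of the patch): the re-enabled
test_embed_models_correctness exercises the recovered embedding path end to
end. A minimal offline equivalent, assuming vLLM's LLM.embed() entrypoint and
a placeholder embedding model name (the model actually used by the test is
defined in the test module and is not shown in this diff):

    from vllm import LLM

    # Placeholder model name for illustration; the e2e test downloads its own
    # checkpoint via modelscope.snapshot_download.
    llm = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed")

    # Same queries as test_embed_models_correctness.
    outputs = llm.embed(["What is the capital of China?", "Explain gravity"])
    for out in outputs:
        # Each result carries the pooled embedding produced by
        # NPUModelRunner._pool() on the NPU worker.
        print(len(out.outputs.embedding))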