[Embedding] Recover embedding function (#2483)
Fix broken embedding function. It's broken by
http://github.com/vllm-project/vllm/pull/23162
- vLLM version: v0.10.1.1
- vLLM main:
efc88cf64a
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -19,7 +19,6 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from modelscope import snapshot_download # type: ignore[import-untyped]
|
||||
|
||||
from tests.e2e.conftest import HfRunner
|
||||
@@ -50,8 +49,6 @@ def test_dummy():
|
||||
assert True
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="TODO: revert me when pooler is adapted with the latest vllm main")
|
||||
def test_embed_models_correctness(hf_runner, vllm_runner):
|
||||
queries = ['What is the capital of China?', 'Explain gravity']
|
||||
|
||||
|
||||
@@ -1545,7 +1545,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states, attn_metadata)
|
||||
return draft_token_ids
|
||||
|
||||
def _pool(
|
||||
def _pool_v010(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
@@ -1579,29 +1579,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
extra_args = ({"kv_connector_output": kv_connector_output})
|
||||
if vllm_version_is("0.10.1.1"):
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
**extra_args,
|
||||
)
|
||||
else:
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
**extra_args,
|
||||
)
|
||||
modelrunner_output = ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
**extra_args,
|
||||
)
|
||||
return modelrunner_output
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
finished_sending: Optional[set[str]] = None,
|
||||
finished_recving: Optional[set[str]] = None,
|
||||
kv_connector_output: Optional["KVConnectorOutput"] = None,
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
"Either all or none of the requests in" \
|
||||
" a batch must be pooling request"
|
||||
|
||||
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||
pooling_metadata = self.input_batch.pooling_metadata
|
||||
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
||||
device=hidden_states.device)
|
||||
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
||||
|
||||
# Pooling models D2H & synchronize occurs in pooler.py:build_output
|
||||
raw_pooler_output = self.model.pooler(
|
||||
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
||||
|
||||
pooler_output: list[Optional[torch.Tensor]] = []
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data)
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
kv_connector_output=kv_connector_output,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -1684,11 +1716,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
logits = None
|
||||
else:
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(
|
||||
hidden_states,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving, kv_connector_output)
|
||||
if vllm_version_is("0.10.1.1"):
|
||||
return self._pool_v010(
|
||||
hidden_states,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving, kv_connector_output)
|
||||
else:
|
||||
return self._pool(
|
||||
hidden_states,
|
||||
scheduler_output.total_num_scheduled_tokens,
|
||||
num_scheduled_tokens_np, finished_sending,
|
||||
finished_recving, kv_connector_output)
|
||||
sample_hidden_states = hidden_states[logits_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if broadcast_pp_output:
|
||||
@@ -2141,10 +2180,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||
req_num_tokens = num_tokens // num_reqs
|
||||
|
||||
dummy_prompt_lens = torch.tensor(
|
||||
[h.shape[0] for h in hidden_states_list],
|
||||
device=self.device,
|
||||
)
|
||||
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
@@ -2153,25 +2188,55 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
dummy_pooling_params = PoolingParams(task=task)
|
||||
to_update = model.pooler.get_pooling_updates(task)
|
||||
to_update.apply(dummy_pooling_params)
|
||||
if vllm_version_is("0.10.1.1"):
|
||||
dummy_prompt_lens = torch.tensor(
|
||||
[h.shape[0] for h in hidden_states_list],
|
||||
device=self.device,
|
||||
)
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=dummy_prompt_lens,
|
||||
prompt_token_ids=dummy_token_ids,
|
||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||
)
|
||||
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=dummy_prompt_lens,
|
||||
prompt_token_ids=dummy_token_ids,
|
||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||
)
|
||||
try:
|
||||
return model.pooler(hidden_states=hidden_states_list,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"NPU out of memory occurred when warming up pooler "
|
||||
f"({task=}) with {num_reqs} dummy requests. Please try "
|
||||
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
dummy_prompt_lens = torch.tensor(
|
||||
num_scheduled_tokens_list,
|
||||
device="cpu",
|
||||
)
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=dummy_prompt_lens,
|
||||
prompt_token_ids=dummy_token_ids,
|
||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||
)
|
||||
|
||||
try:
|
||||
return model.pooler(hidden_states=hidden_states_list,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"NPU out of memory occurred when warming up pooler "
|
||||
f"({task=}) with {num_reqs} dummy requests. Please try "
|
||||
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
|
||||
device=hidden_states.device)
|
||||
|
||||
try:
|
||||
return model.pooler(hidden_states=hidden_states,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"CUDA out of memory occurred when warming up pooler "
|
||||
f"({task=}) with {num_reqs} dummy requests. Please try "
|
||||
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(
|
||||
|
||||
@@ -39,6 +39,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedRequestState:
|
||||
@@ -724,13 +726,20 @@ class InputBatch:
|
||||
pooling_params = [
|
||||
self.pooling_params[req_id] for req_id in self.req_ids
|
||||
]
|
||||
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
if vllm_version_is("0.10.1.1"):
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
else:
|
||||
return PoolingMetadata(
|
||||
prompt_lens=torch.from_numpy(
|
||||
self.num_prompt_tokens[:self.num_reqs]),
|
||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||
|
||||
Reference in New Issue
Block a user