[Embedding] Recover embedding function (#2483)
Fix the broken embedding function. It was broken by
http://github.com/vllm-project/vllm/pull/23162
- vLLM version: v0.10.1.1
- vLLM main: efc88cf64a
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -19,7 +19,6 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
|
||||||
from modelscope import snapshot_download # type: ignore[import-untyped]
|
from modelscope import snapshot_download # type: ignore[import-untyped]
|
||||||
|
|
||||||
from tests.e2e.conftest import HfRunner
|
from tests.e2e.conftest import HfRunner
|
||||||
@@ -50,8 +49,6 @@ def test_dummy():
|
|||||||
assert True
|
assert True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skip(
|
|
||||||
reason="TODO: revert me when pooler is adapted with the latest vllm main")
|
|
||||||
def test_embed_models_correctness(hf_runner, vllm_runner):
|
def test_embed_models_correctness(hf_runner, vllm_runner):
|
||||||
queries = ['What is the capital of China?', 'Explain gravity']
|
queries = ['What is the capital of China?', 'Explain gravity']
|
||||||
|
|
||||||
|
|||||||
@@ -1545,7 +1545,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
hidden_states, attn_metadata)
|
hidden_states, attn_metadata)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def _pool(
|
def _pool_v010(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
num_scheduled_tokens: int,
|
num_scheduled_tokens: int,
|
||||||
@@ -1579,29 +1579,61 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
else:
|
else:
|
||||||
pooler_output.append(None)
|
pooler_output.append(None)
|
||||||
extra_args = ({"kv_connector_output": kv_connector_output})
|
extra_args = ({"kv_connector_output": kv_connector_output})
|
||||||
if vllm_version_is("0.10.1.1"):
|
modelrunner_output = ModelRunnerOutput(
|
||||||
modelrunner_output = ModelRunnerOutput(
|
req_ids=self.input_batch.req_ids,
|
||||||
req_ids=self.input_batch.req_ids,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
sampled_token_ids=[],
|
||||||
sampled_token_ids=[],
|
spec_token_ids=None,
|
||||||
spec_token_ids=None,
|
logprobs=None,
|
||||||
logprobs=None,
|
prompt_logprobs_dict={},
|
||||||
prompt_logprobs_dict={},
|
pooler_output=pooler_output,
|
||||||
pooler_output=pooler_output,
|
**extra_args,
|
||||||
**extra_args,
|
)
|
||||||
)
|
|
||||||
else:
|
|
||||||
modelrunner_output = ModelRunnerOutput(
|
|
||||||
req_ids=self.input_batch.req_ids,
|
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
|
||||||
sampled_token_ids=[],
|
|
||||||
logprobs=None,
|
|
||||||
prompt_logprobs_dict={},
|
|
||||||
pooler_output=pooler_output,
|
|
||||||
**extra_args,
|
|
||||||
)
|
|
||||||
return modelrunner_output
|
return modelrunner_output
|
||||||
|
|
||||||
|
def _pool(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
num_scheduled_tokens: int,
|
||||||
|
num_scheduled_tokens_np: np.ndarray,
|
||||||
|
finished_sending: Optional[set[str]] = None,
|
||||||
|
finished_recving: Optional[set[str]] = None,
|
||||||
|
kv_connector_output: Optional["KVConnectorOutput"] = None,
|
||||||
|
) -> ModelRunnerOutput:
|
||||||
|
assert self.input_batch.num_reqs ==\
|
||||||
|
len(self.input_batch.pooling_params), \
|
||||||
|
"Either all or none of the requests in" \
|
||||||
|
" a batch must be pooling request"
|
||||||
|
|
||||||
|
hidden_states = hidden_states[:num_scheduled_tokens]
|
||||||
|
pooling_metadata = self.input_batch.pooling_metadata
|
||||||
|
pooling_metadata.build_pooling_cursor(num_scheduled_tokens_np.tolist(),
|
||||||
|
device=hidden_states.device)
|
||||||
|
seq_lens_cpu = self.seq_lens_cpu[:self.input_batch.num_reqs]
|
||||||
|
|
||||||
|
# Pooling models D2H & synchronize occurs in pooler.py:build_output
|
||||||
|
raw_pooler_output = self.model.pooler(
|
||||||
|
hidden_states=hidden_states, pooling_metadata=pooling_metadata)
|
||||||
|
|
||||||
|
pooler_output: list[Optional[torch.Tensor]] = []
|
||||||
|
for raw_output, seq_len, prompt_len in zip(
|
||||||
|
raw_pooler_output, seq_lens_cpu, pooling_metadata.prompt_lens):
|
||||||
|
|
||||||
|
if seq_len == prompt_len:
|
||||||
|
pooler_output.append(raw_output.data)
|
||||||
|
else:
|
||||||
|
pooler_output.append(None)
|
||||||
|
|
||||||
|
return ModelRunnerOutput(
|
||||||
|
req_ids=self.input_batch.req_ids,
|
||||||
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
|
sampled_token_ids=[],
|
||||||
|
logprobs=None,
|
||||||
|
prompt_logprobs_dict={},
|
||||||
|
pooler_output=pooler_output,
|
||||||
|
kv_connector_output=kv_connector_output,
|
||||||
|
)
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
self,
|
self,
|
||||||
@@ -1684,11 +1716,18 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
logits = None
|
logits = None
|
||||||
else:
|
else:
|
||||||
if self.input_batch.pooling_params:
|
if self.input_batch.pooling_params:
|
||||||
return self._pool(
|
if vllm_version_is("0.10.1.1"):
|
||||||
hidden_states,
|
return self._pool_v010(
|
||||||
scheduler_output.total_num_scheduled_tokens,
|
hidden_states,
|
||||||
num_scheduled_tokens_np, finished_sending,
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
finished_recving, kv_connector_output)
|
num_scheduled_tokens_np, finished_sending,
|
||||||
|
finished_recving, kv_connector_output)
|
||||||
|
else:
|
||||||
|
return self._pool(
|
||||||
|
hidden_states,
|
||||||
|
scheduler_output.total_num_scheduled_tokens,
|
||||||
|
num_scheduled_tokens_np, finished_sending,
|
||||||
|
finished_recving, kv_connector_output)
|
||||||
sample_hidden_states = hidden_states[logits_indices]
|
sample_hidden_states = hidden_states[logits_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
if broadcast_pp_output:
|
if broadcast_pp_output:
|
||||||
@@ -2141,10 +2180,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
torch.split(hidden_states, num_scheduled_tokens_list))
|
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||||
req_num_tokens = num_tokens // num_reqs
|
req_num_tokens = num_tokens // num_reqs
|
||||||
|
|
||||||
dummy_prompt_lens = torch.tensor(
|
|
||||||
[h.shape[0] for h in hidden_states_list],
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
dummy_token_ids = torch.zeros((num_reqs, req_num_tokens),
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device)
|
device=self.device)
|
||||||
@@ -2153,25 +2188,55 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dummy_pooling_params = PoolingParams(task=task)
|
dummy_pooling_params = PoolingParams(task=task)
|
||||||
to_update = model.pooler.get_pooling_updates(task)
|
to_update = model.pooler.get_pooling_updates(task)
|
||||||
to_update.apply(dummy_pooling_params)
|
to_update.apply(dummy_pooling_params)
|
||||||
|
if vllm_version_is("0.10.1.1"):
|
||||||
|
dummy_prompt_lens = torch.tensor(
|
||||||
|
[h.shape[0] for h in hidden_states_list],
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
dummy_metadata = PoolingMetadata(
|
||||||
|
prompt_lens=dummy_prompt_lens,
|
||||||
|
prompt_token_ids=dummy_token_ids,
|
||||||
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||||
|
)
|
||||||
|
|
||||||
dummy_metadata = PoolingMetadata(
|
try:
|
||||||
prompt_lens=dummy_prompt_lens,
|
return model.pooler(hidden_states=hidden_states_list,
|
||||||
prompt_token_ids=dummy_token_ids,
|
pooling_metadata=dummy_metadata)
|
||||||
pooling_params=[dummy_pooling_params] * num_reqs,
|
except RuntimeError as e:
|
||||||
)
|
if 'out of memory' in str(e):
|
||||||
|
raise RuntimeError(
|
||||||
|
"NPU out of memory occurred when warming up pooler "
|
||||||
|
f"({task=}) with {num_reqs} dummy requests. Please try "
|
||||||
|
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||||
|
"initializing the engine.") from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
else:
|
||||||
|
dummy_prompt_lens = torch.tensor(
|
||||||
|
num_scheduled_tokens_list,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
dummy_metadata = PoolingMetadata(
|
||||||
|
prompt_lens=dummy_prompt_lens,
|
||||||
|
prompt_token_ids=dummy_token_ids,
|
||||||
|
pooling_params=[dummy_pooling_params] * num_reqs,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
dummy_metadata.build_pooling_cursor(num_scheduled_tokens_list,
|
||||||
return model.pooler(hidden_states=hidden_states_list,
|
device=hidden_states.device)
|
||||||
pooling_metadata=dummy_metadata)
|
|
||||||
except RuntimeError as e:
|
try:
|
||||||
if 'out of memory' in str(e):
|
return model.pooler(hidden_states=hidden_states,
|
||||||
raise RuntimeError(
|
pooling_metadata=dummy_metadata)
|
||||||
"NPU out of memory occurred when warming up pooler "
|
except RuntimeError as e:
|
||||||
f"({task=}) with {num_reqs} dummy requests. Please try "
|
if 'out of memory' in str(e):
|
||||||
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
raise RuntimeError(
|
||||||
"initializing the engine.") from e
|
"CUDA out of memory occurred when warming up pooler "
|
||||||
else:
|
f"({task=}) with {num_reqs} dummy requests. Please try "
|
||||||
raise e
|
"lowering `max_num_seqs` or `gpu_memory_utilization` when "
|
||||||
|
"initializing the engine.") from e
|
||||||
|
else:
|
||||||
|
raise e
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def _dummy_pooler_run(
|
def _dummy_pooler_run(
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
|
|||||||
from vllm.v1.utils import copy_slice
|
from vllm.v1.utils import copy_slice
|
||||||
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
from vllm.v1.worker.block_table import MultiGroupBlockTable
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CachedRequestState:
|
class CachedRequestState:
|
||||||
@@ -724,13 +726,20 @@ class InputBatch:
|
|||||||
pooling_params = [
|
pooling_params = [
|
||||||
self.pooling_params[req_id] for req_id in self.req_ids
|
self.pooling_params[req_id] for req_id in self.req_ids
|
||||||
]
|
]
|
||||||
|
if vllm_version_is("0.10.1.1"):
|
||||||
return PoolingMetadata(
|
return PoolingMetadata(
|
||||||
prompt_lens=torch.from_numpy(
|
prompt_lens=torch.from_numpy(
|
||||||
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
|
||||||
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||||
pooling_params=pooling_params,
|
pooling_params=pooling_params,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
return PoolingMetadata(
|
||||||
|
prompt_lens=torch.from_numpy(
|
||||||
|
self.num_prompt_tokens[:self.num_reqs]),
|
||||||
|
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
|
||||||
|
pooling_params=pooling_params,
|
||||||
|
)
|
||||||
|
|
||||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||||
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()
|
||||||
|
|||||||
Reference in New Issue
Block a user