[Embedding] Recover embedding function (#2483)

Fix broken embedding function. It's broken by
http://github.com/vllm-project/vllm/pull/23162

- vLLM version: v0.10.1.1
- vLLM main:
efc88cf64a

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
wangxiyuan
2025-08-27 09:22:01 +08:00
committed by GitHub
parent 6a4ec186e7
commit f22077daa6
3 changed files with 129 additions and 58 deletions

View File

@@ -39,6 +39,8 @@ from vllm.v1.spec_decode.utils import is_spec_decode_unsupported
from vllm.v1.utils import copy_slice
from vllm.v1.worker.block_table import MultiGroupBlockTable
from vllm_ascend.utils import vllm_version_is
@dataclass
class CachedRequestState:
@@ -724,13 +726,20 @@ class InputBatch:
pooling_params = [
self.pooling_params[req_id] for req_id in self.req_ids
]
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
if vllm_version_is("0.10.1.1"):
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]).to(self.device),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
else:
return PoolingMetadata(
prompt_lens=torch.from_numpy(
self.num_prompt_tokens[:self.num_reqs]),
prompt_token_ids=self.sampling_metadata.prompt_token_ids,
pooling_params=pooling_params,
)
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
max_prompt_len = self.num_prompt_tokens[:self.num_reqs].max()