[V1][ModelRunner] Support pooling model for v1 engine (#1359)
### What this PR does / why we need it? Change as little existing code as possible to add v1 pooling task's support, notice that i move down the `vllm.v1.worker.gpu_input_batch` to vllm-ascend, Considering the frequent changes in upstream interfaces, in order to decouple, so i move it here ### How was this patch tested? CI passed with new added/existing test, and I have a simple test was first conducted locally which is adapted from https://www.modelscope.cn/models/Qwen/Qwen3-Embedding-0.6B, just like bellow: ```python import os import torch from vllm import LLM os.environ["VLLM_USE_MODELSCOPE"]="True" def get_detailed_instruct(task_description: str, query: str) -> str: return f'Instruct: {task_description}\nQuery:{query}' # Each query must come with a one-sentence instruction that describes the task task = 'Given a web search query, retrieve relevant passages that answer the query' queries = [ get_detailed_instruct(task, 'What is the capital of China?'), get_detailed_instruct(task, 'Explain gravity') ] # No need to add instruction for retrieval documents documents = [ "The capital of China is Beijing.", "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun." ] input_texts = queries + documents model = LLM(model="Qwen/Qwen3-Embedding-0.6B", task="embed") outputs = model.embed(input_texts) embeddings = torch.tensor([o.outputs.embedding for o in outputs]) scores = (embeddings[:2] @ embeddings[2:].T) print(scores.tolist()) # [[0.7620252966880798, 0.14078938961029053], [0.1358368694782257, 0.6013815999031067]] ``` --------- Signed-off-by: wangli <wangli858794774@gmail.com> Signed-off-by: wangli <858794774@qq.com> Co-authored-by: wangli <858794774@qq.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from vllm.model_executor.model_loader import get_model
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||
from vllm.multimodal.utils import group_mm_inputs_by_modality
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingType
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
|
||||
@@ -62,7 +63,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||
from vllm.v1.spec_decode.utils import is_spec_decode_supported
|
||||
from vllm.v1.utils import bind_kv_cache
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||
from vllm.v1.worker.utils import (gather_mm_placeholders,
|
||||
sanity_check_mm_encoder_outputs,
|
||||
@@ -74,12 +74,14 @@ from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
|
||||
AscendMetadata)
|
||||
from vllm_ascend.attention.mla_v1 import CommonAttentionMetadata
|
||||
from vllm_ascend.platform import NPUPlatform
|
||||
from vllm_ascend.pool.metadata import PoolingMetadata
|
||||
from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||
ProfileExecuteDuration, is_310p,
|
||||
vllm_version_is)
|
||||
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr # type: ignore[import-untyped]
|
||||
@@ -177,6 +179,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||
self.is_pooling_model = self.model_config.pooler_config is not None
|
||||
if self.is_multimodal_model:
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.model_config.get_hidden_size()),
|
||||
@@ -389,38 +392,29 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
for new_req_data in scheduler_output.scheduled_new_reqs:
|
||||
req_id = new_req_data.req_id
|
||||
sampling_params = new_req_data.sampling_params
|
||||
if sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
if sampling_params and \
|
||||
sampling_params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
generator = torch.Generator(device=self.device)
|
||||
generator.manual_seed(sampling_params.seed)
|
||||
else:
|
||||
generator = None
|
||||
if vllm_version_is("0.9.1"):
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
else:
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
|
||||
# For vllm v0.9.1 version compatibility, we check if
|
||||
# `pooling_params` is present in the new request data.
|
||||
pooling_params = getattr(new_req_data, "pooling_params", None)
|
||||
self.requests[req_id] = CachedRequestState(
|
||||
req_id=req_id,
|
||||
prompt_token_ids=new_req_data.prompt_token_ids,
|
||||
mm_inputs=new_req_data.mm_inputs,
|
||||
mm_positions=new_req_data.mm_positions,
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=pooling_params,
|
||||
generator=generator,
|
||||
block_ids=new_req_data.block_ids,
|
||||
num_computed_tokens=new_req_data.num_computed_tokens,
|
||||
output_token_ids=[],
|
||||
lora_request=new_req_data.lora_request,
|
||||
)
|
||||
|
||||
# Only relevant for models using M-RoPE (e.g, Qwen2-VL)
|
||||
if self.uses_mrope:
|
||||
@@ -893,7 +887,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
scheduler_output: "SchedulerOutput",
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
) -> tuple[SpecDecodeMetadata, torch.Tensor, SpecDecodeMetadata,
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor]:
|
||||
torch.Tensor, int, torch.Tensor, torch.Tensor, np.ndarray]:
|
||||
# Check input valid
|
||||
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
assert total_num_scheduled_tokens > 0
|
||||
@@ -1173,7 +1167,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states, aux_hidden_states = hidden_states
|
||||
|
||||
return (attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
total_num_scheduled_tokens, sample_indices, aux_hidden_states)
|
||||
total_num_scheduled_tokens, sample_indices, aux_hidden_states,
|
||||
num_scheduled_tokens)
|
||||
|
||||
def _get_cumsum_and_arange(
|
||||
self,
|
||||
@@ -1431,6 +1426,47 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states, attn_metadata)
|
||||
return spec_token_ids
|
||||
|
||||
def _pool(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
num_scheduled_tokens: int,
|
||||
num_scheduled_tokens_np: np.ndarray,
|
||||
) -> ModelRunnerOutput:
|
||||
assert self.input_batch.num_reqs ==\
|
||||
len(self.input_batch.pooling_params), \
|
||||
"Either all or none of the requests in" \
|
||||
" a batch must be pooling request"
|
||||
|
||||
extracted_hidden_states = list(
|
||||
torch.split(hidden_states[:num_scheduled_tokens],
|
||||
num_scheduled_tokens_np.tolist()))
|
||||
|
||||
pooling_metadata = self.input_batch.pooling_metadata
|
||||
|
||||
raw_pooler_output = self.model.pooler(
|
||||
hidden_states=extracted_hidden_states,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
pooler_output: list[Optional[torch.Tensor]] = []
|
||||
seq_lens = self.seq_lens[:self.input_batch.num_reqs]
|
||||
for raw_output, seq_len, prompt_len in zip(
|
||||
raw_pooler_output, seq_lens, pooling_metadata.prompt_lens):
|
||||
|
||||
if seq_len == prompt_len:
|
||||
pooler_output.append(raw_output.data.cpu())
|
||||
else:
|
||||
pooler_output.append(None)
|
||||
|
||||
return ModelRunnerOutput(
|
||||
req_ids=self.input_batch.req_ids,
|
||||
req_id_to_index=self.input_batch.req_id_to_index,
|
||||
sampled_token_ids=[],
|
||||
spec_token_ids=None,
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=pooler_output,
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -1444,12 +1480,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
# Return empty ModelRunnerOuptut if there's no work to do.
|
||||
return EMPTY_MODEL_RUNNER_OUTPUT
|
||||
(attn_metadata, hidden_states, spec_decode_metadata, positions,
|
||||
num_scheduled_tokens, sample_indices,
|
||||
aux_hidden_states) = (self._process_reqs(scheduler_output,
|
||||
intermediate_tensors))
|
||||
num_scheduled_tokens, sample_indices, aux_hidden_states,
|
||||
num_scheduled_tokens_np) = (self._process_reqs(
|
||||
scheduler_output, intermediate_tensors))
|
||||
|
||||
with ProfileExecuteDuration().capture_async("post process"):
|
||||
|
||||
if self.input_batch.pooling_params:
|
||||
return self._pool(hidden_states, num_scheduled_tokens,
|
||||
num_scheduled_tokens_np)
|
||||
logits = self.model.compute_logits(hidden_states[sample_indices],
|
||||
None)
|
||||
if self.use_eagle:
|
||||
@@ -1795,21 +1834,75 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
hidden_states = self._dummy_run(self.max_num_tokens)
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
logits = self.model.compute_logits(hidden_states, None)
|
||||
if self.is_pooling_model:
|
||||
output = self._dummy_pooler_run(hidden_states)
|
||||
else:
|
||||
# TODO: need to rum a dummy sampler for generate task
|
||||
hidden_states = hidden_states[logit_indices]
|
||||
output = self.model.compute_logits(hidden_states, None)
|
||||
else:
|
||||
logits = None
|
||||
output = None
|
||||
|
||||
NPUPlatform.synchronize()
|
||||
del hidden_states, logits
|
||||
del hidden_states, output
|
||||
self.encoder_cache.clear()
|
||||
gc.collect()
|
||||
|
||||
@torch.inference_mode()
|
||||
def _dummy_pooler_run(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
num_tokens = hidden_states.shape[0]
|
||||
max_num_reqs = self.scheduler_config.max_num_seqs
|
||||
num_reqs = min(num_tokens, max_num_reqs)
|
||||
min_tokens_per_req = num_tokens // num_reqs
|
||||
num_scheduled_tokens_list = [min_tokens_per_req] * num_reqs
|
||||
num_scheduled_tokens_list[-1] += num_tokens % num_reqs
|
||||
assert sum(num_scheduled_tokens_list) == num_tokens
|
||||
assert len(num_scheduled_tokens_list) == num_reqs
|
||||
|
||||
hidden_states_list = list(
|
||||
torch.split(hidden_states, num_scheduled_tokens_list))
|
||||
|
||||
req_num_tokens = num_tokens // num_reqs
|
||||
|
||||
dummy_metadata = PoolingMetadata(
|
||||
prompt_lens=torch.tensor([h.shape[0] for h in hidden_states_list],
|
||||
device=self.device),
|
||||
prompt_token_ids=torch.zeros((num_reqs, req_num_tokens),
|
||||
dtype=torch.int32,
|
||||
device=self.device),
|
||||
pooling_params=[PoolingParams()] * num_reqs)
|
||||
|
||||
try:
|
||||
pooler_output = self.model.pooler(hidden_states=hidden_states_list,
|
||||
pooling_metadata=dummy_metadata)
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
raise RuntimeError(
|
||||
"NPU out of memory occurred when warming up pooler with "
|
||||
f"{num_reqs} dummy requests. Please try lowering "
|
||||
"`max_num_seqs` or `gpu_memory_utilization` when "
|
||||
"initializing the engine.") from e
|
||||
else:
|
||||
raise e
|
||||
return pooler_output
|
||||
|
||||
def load_model(self) -> None:
|
||||
logger.info("Starting to load model %s...", self.model_config.model)
|
||||
|
||||
with DeviceMemoryProfiler() as m: # noqa: SIM117
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
try:
|
||||
# For version compatibility, remove this after we abort vllm v0.9.1 support
|
||||
from vllm.model_executor.models.interfaces import \
|
||||
has_step_pooler # type: ignore
|
||||
if has_step_pooler(self.model):
|
||||
self.input_batch.logits_processing_needs_token_ids = True
|
||||
except ImportError:
|
||||
pass
|
||||
if self.drafter:
|
||||
logger.info("Loading drafter model...")
|
||||
if self.use_aux_hidden_state_outputs:
|
||||
|
||||
Reference in New Issue
Block a user