Sync from v0.13
This commit is contained in:
230
vllm/pooling_params.py
Normal file
230
vllm/pooling_params.py
Normal file
@@ -0,0 +1,230 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Annotated, Any, Optional
|
||||
|
||||
import msgspec
|
||||
|
||||
from vllm.config import ModelConfig, PoolerConfig
|
||||
from vllm.config.pooler import get_use_activation
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.tasks import PoolingTask
|
||||
|
||||
|
||||
class PoolingParams(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
array_like=True,
|
||||
): # type: ignore[call-arg]
|
||||
"""API parameters for pooling models.
|
||||
|
||||
Attributes:
|
||||
truncate_prompt_tokens: Controls prompt truncation.
|
||||
Set to -1 to use the model's default truncation size.
|
||||
Set to k to keep only the last k tokens (left truncation).
|
||||
Set to None to disable truncation.
|
||||
dimensions: Reduce the dimensions of embeddings
|
||||
if model support matryoshka representation.
|
||||
normalize: Whether to normalize the embeddings outputs.
|
||||
softmax: softmax will be deprecated, please use use_activation instead.
|
||||
activation: activation will be deprecated, please use use_activation instead.
|
||||
use_activation: Whether to apply activation function to
|
||||
the classification outputs.
|
||||
"""
|
||||
|
||||
# --8<-- [start:common-pooling-params]
|
||||
truncate_prompt_tokens: Annotated[int, msgspec.Meta(ge=-1)] | None = None
|
||||
# --8<-- [end:common-pooling-params]
|
||||
|
||||
## for embeddings models
|
||||
# --8<-- [start:embedding-pooling-params]
|
||||
dimensions: int | None = None
|
||||
normalize: bool | None = None
|
||||
# --8<-- [end:embedding-pooling-params]
|
||||
|
||||
## for classification, scoring and rerank
|
||||
# --8<-- [start:classification-pooling-params]
|
||||
softmax: bool | None = None
|
||||
activation: bool | None = None
|
||||
use_activation: bool | None = None
|
||||
# --8<-- [end:classification-pooling-params]
|
||||
|
||||
## for step pooling models
|
||||
step_tag_id: int | None = None
|
||||
returned_token_ids: list[int] | None = None
|
||||
|
||||
## Internal use only
|
||||
task: PoolingTask | None = None
|
||||
requires_token_ids: bool = False
|
||||
skip_reading_prefix_cache: bool | None = None
|
||||
extra_kwargs: dict[str, Any] | None = None
|
||||
output_kind: RequestOutputKind = RequestOutputKind.FINAL_ONLY
|
||||
|
||||
@property
|
||||
def all_parameters(self) -> list[str]:
|
||||
return ["dimensions", "normalize", "use_activation"]
|
||||
|
||||
@property
|
||||
def valid_parameters(self):
|
||||
return {
|
||||
"embed": ["dimensions", "normalize"],
|
||||
"classify": ["use_activation"],
|
||||
"score": ["use_activation"],
|
||||
"token_embed": ["dimensions", "normalize"],
|
||||
"token_classify": ["use_activation"],
|
||||
}
|
||||
|
||||
def clone(self) -> "PoolingParams":
|
||||
"""Returns a deep copy of the PoolingParams instance."""
|
||||
return deepcopy(self)
|
||||
|
||||
def verify(
|
||||
self, task: PoolingTask, model_config: Optional["ModelConfig"] = None
|
||||
) -> None:
|
||||
if self.task is None:
|
||||
self.task = task
|
||||
elif self.task != task:
|
||||
msg = f"You cannot overwrite {self.task=!r} with {task=!r}!"
|
||||
raise ValueError(msg)
|
||||
|
||||
# raise deprecated warning for softmax and activation
|
||||
self.use_activation = get_use_activation(self)
|
||||
|
||||
# plugin task uses io_processor.parse_request to verify inputs,
|
||||
# skipping PoolingParams verify
|
||||
if self.task == "plugin":
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
self.skip_reading_prefix_cache = True
|
||||
return
|
||||
|
||||
# NOTE: Task validation needs to done against the model instance,
|
||||
# which is not available in model config. So, it's not included
|
||||
# in this method
|
||||
self._merge_default_parameters(model_config)
|
||||
self._set_default_parameters(model_config)
|
||||
self._verify_valid_parameters()
|
||||
|
||||
def _merge_default_parameters(
|
||||
self, model_config: Optional["ModelConfig"] = None
|
||||
) -> None:
|
||||
if model_config is None:
|
||||
return
|
||||
|
||||
pooler_config = model_config.pooler_config
|
||||
if pooler_config is None:
|
||||
return
|
||||
|
||||
assert self.task is not None, "task must be set"
|
||||
valid_parameters = self.valid_parameters[self.task]
|
||||
|
||||
for k in valid_parameters:
|
||||
if getattr(pooler_config, k, None) is None:
|
||||
continue
|
||||
|
||||
if getattr(self, k, None) is None:
|
||||
setattr(self, k, getattr(pooler_config, k))
|
||||
|
||||
if self.skip_reading_prefix_cache is None:
|
||||
# If prefix caching is enabled,
|
||||
# the output of all pooling may less than n_prompt_tokens,
|
||||
# we need to skip reading cache at this request.
|
||||
if self.task in ["token_embed", "token_classify"]:
|
||||
self.skip_reading_prefix_cache = True
|
||||
else:
|
||||
self.skip_reading_prefix_cache = False
|
||||
|
||||
self._verify_step_pooling(pooler_config, valid_parameters)
|
||||
|
||||
def _verify_step_pooling(
|
||||
self, pooler_config: "PoolerConfig", valid_parameters: list[str]
|
||||
):
|
||||
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||
if pooler_config.pooling_type != "STEP":
|
||||
invalid_parameters = []
|
||||
for k in step_pooling_parameters:
|
||||
if getattr(self, k, None) is not None:
|
||||
invalid_parameters.append(k)
|
||||
|
||||
if invalid_parameters:
|
||||
raise ValueError(
|
||||
f"Task {self.task} only supports {valid_parameters} "
|
||||
f"parameters, does not support "
|
||||
f"{invalid_parameters} parameters"
|
||||
)
|
||||
else:
|
||||
for k in step_pooling_parameters:
|
||||
if getattr(pooler_config, k, None) is None:
|
||||
continue
|
||||
|
||||
if getattr(self, k, None) is None:
|
||||
setattr(self, k, getattr(pooler_config, k))
|
||||
|
||||
def _set_default_parameters(self, model_config: Optional["ModelConfig"]):
|
||||
if self.task in ["embed", "token_embed"]:
|
||||
if self.normalize is None:
|
||||
self.normalize = True
|
||||
|
||||
if self.dimensions is not None and model_config is not None:
|
||||
if not model_config.is_matryoshka:
|
||||
raise ValueError(
|
||||
f'Model "{model_config.served_model_name}" does not '
|
||||
f"support matryoshka representation, "
|
||||
f"changing output dimensions will lead to poor results."
|
||||
)
|
||||
|
||||
mds = model_config.matryoshka_dimensions
|
||||
if mds is not None:
|
||||
if self.dimensions not in mds:
|
||||
raise ValueError(
|
||||
f'Model "{model_config.served_model_name}" '
|
||||
f"only supports {str(mds)} matryoshka dimensions, "
|
||||
f"use other output dimensions will "
|
||||
f"lead to poor results."
|
||||
)
|
||||
elif self.dimensions < 1:
|
||||
raise ValueError("Dimensions must be greater than 0")
|
||||
|
||||
elif self.task in ["classify", "score", "token_classify"]:
|
||||
if self.use_activation is None:
|
||||
self.use_activation = True
|
||||
else:
|
||||
raise ValueError(f"Unknown pooling task: {self.task}")
|
||||
|
||||
def _verify_valid_parameters(self):
|
||||
assert self.task is not None, "task must be set"
|
||||
valid_parameters = self.valid_parameters[self.task]
|
||||
invalid_parameters = []
|
||||
for k in self.all_parameters:
|
||||
if k in valid_parameters:
|
||||
continue
|
||||
|
||||
if getattr(self, k, None) is not None:
|
||||
invalid_parameters.append(k)
|
||||
|
||||
if invalid_parameters:
|
||||
raise ValueError(
|
||||
f"Task {self.task} only supports {valid_parameters} "
|
||||
f"parameters, does not support "
|
||||
f"{invalid_parameters} parameters"
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"PoolingParams("
|
||||
f"task={self.task}, "
|
||||
f"normalize={self.normalize}, "
|
||||
f"dimensions={self.dimensions}, "
|
||||
f"use_activation={self.use_activation}, "
|
||||
f"step_tag_id={self.step_tag_id}, "
|
||||
f"returned_token_ids={self.returned_token_ids}, "
|
||||
f"requires_token_ids={self.requires_token_ids}, "
|
||||
f"skip_reading_prefix_cache={self.skip_reading_prefix_cache}, "
|
||||
f"truncate_prompt_tokens={self.truncate_prompt_tokens}, "
|
||||
f"extra_kwargs={self.extra_kwargs})"
|
||||
)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.output_kind == RequestOutputKind.FINAL_ONLY, (
|
||||
"For pooling output_kind has to be FINAL_ONLY"
|
||||
)
|
||||
Reference in New Issue
Block a user