init
This commit is contained in:
33
vllm/transformers_utils/tokenizer_group/__init__.py
Normal file
33
vllm/transformers_utils/tokenizer_group/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import TokenizerPoolConfig
|
||||
from vllm.executor.ray_utils import ray
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
||||
TokenizerGroup)
|
||||
|
||||
if ray:
|
||||
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
|
||||
RayTokenizerGroupPool)
|
||||
else:
|
||||
RayTokenizerGroupPool = None # type: ignore
|
||||
|
||||
|
||||
def get_tokenizer_group(tokenizer_pool_config: Optional[TokenizerPoolConfig],
|
||||
**init_kwargs) -> BaseTokenizerGroup:
|
||||
if tokenizer_pool_config is None:
|
||||
return TokenizerGroup(**init_kwargs)
|
||||
if tokenizer_pool_config.pool_type == "ray":
|
||||
if RayTokenizerGroupPool is None:
|
||||
raise ImportError(
|
||||
"RayTokenizerGroupPool is not available. Please install "
|
||||
"the ray package to use the Ray tokenizer group pool.")
|
||||
return RayTokenizerGroupPool.from_config(tokenizer_pool_config,
|
||||
**init_kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown pool type: {tokenizer_pool_config.pool_type}")
|
||||
|
||||
|
||||
__all__ = ["get_tokenizer_group", "BaseTokenizerGroup"]
|
||||
@@ -0,0 +1,55 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
|
||||
class BaseTokenizerGroup(ABC):
|
||||
"""A group of tokenizers that can be used for LoRA adapters."""
|
||||
|
||||
@abstractmethod
|
||||
def ping(self) -> bool:
|
||||
"""Check if the tokenizer group is alive."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_max_input_len(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
"""Get the maximum input length for the LoRA request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def encode(self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
"""Encode a prompt using the tokenizer group."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def encode_async(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
"""Encode a prompt using the tokenizer group."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_lora_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
"""Get a tokenizer for a LoRA request."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def get_lora_tokenizer_async(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
"""Get a tokenizer for a LoRA request."""
|
||||
pass
|
||||
169
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
Normal file
169
vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
Normal file
@@ -0,0 +1,169 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.config import TokenizerPoolConfig
|
||||
from vllm.executor.ray_utils import ray
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
|
||||
TokenizerGroup)
|
||||
|
||||
|
||||
class RayTokenizerGroupPool(BaseTokenizerGroup):
|
||||
"""A Ray-based pool of TokenizerGroups for async tokenization."""
|
||||
|
||||
# Class to use for workers making up the pool.
|
||||
_worker_cls = TokenizerGroup
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, tokenizer_pool_config: TokenizerPoolConfig,
|
||||
**init_kwargs) -> "RayTokenizerGroupPool":
|
||||
ray_actor_options = (tokenizer_pool_config.extra_config or {
|
||||
"num_cpus": 0
|
||||
})
|
||||
ray_actor_options.setdefault(
|
||||
"scheduling_strategy",
|
||||
NodeAffinitySchedulingStrategy(
|
||||
node_id=ray.get_runtime_context().get_node_id(), soft=True))
|
||||
|
||||
# Carry over the env vars to the actors.
|
||||
# This is necessary for API keys and such.
|
||||
ray_actor_options.setdefault("runtime_env", {})
|
||||
_carry_over_env_vars_to_runtime_env(ray_actor_options["runtime_env"])
|
||||
|
||||
init_kwargs["num_actors"] = tokenizer_pool_config.pool_size
|
||||
init_kwargs["ray_actor_options"] = ray_actor_options
|
||||
|
||||
return cls(**init_kwargs)
|
||||
|
||||
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
|
||||
max_input_length: Optional[int], num_actors: int,
|
||||
ray_actor_options: dict, **tokenizer_config):
|
||||
# Store a local copy of the TokenizerGroup for quick access
|
||||
# to underlying HF tokenizers.
|
||||
self._local_tokenizer_group = self._worker_cls(
|
||||
tokenizer_id=tokenizer_id,
|
||||
enable_lora=enable_lora,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_input_length=max_input_length,
|
||||
**tokenizer_config,
|
||||
)
|
||||
|
||||
ray_tokenizer_group_cls = ray.remote(
|
||||
self._worker_cls).options(**ray_actor_options)
|
||||
self.tokenizer_actors = [
|
||||
ray_tokenizer_group_cls.remote(tokenizer_id, enable_lora,
|
||||
max_num_seqs, max_input_length,
|
||||
**tokenizer_config)
|
||||
for _ in range(num_actors)
|
||||
]
|
||||
self._idle_actors: Optional[asyncio.Queue] = None
|
||||
|
||||
@property
|
||||
def pool_size(self) -> int:
|
||||
return len(self.tokenizer_actors)
|
||||
|
||||
def ping(self):
|
||||
return ray.get(
|
||||
[actor.ping.remote() for actor in self.tokenizer_actors])
|
||||
|
||||
def _ensure_queue_initialized(self):
|
||||
if self._idle_actors is None:
|
||||
self._idle_actors = asyncio.Queue()
|
||||
for actor in self.tokenizer_actors:
|
||||
self._idle_actors.put_nowait(actor)
|
||||
|
||||
def encode(self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
"""Encode a prompt using the tokenizer group.
|
||||
|
||||
We pick an idle actor and use it to encode the prompt.
|
||||
The actor is then put back in the queue for future use.
|
||||
This is blocking.
|
||||
"""
|
||||
self._ensure_queue_initialized()
|
||||
assert self._idle_actors is not None
|
||||
|
||||
if self._idle_actors.empty():
|
||||
raise RuntimeError("No idle actors available.")
|
||||
actor = self._idle_actors.get_nowait()
|
||||
try:
|
||||
ret = ray.get(
|
||||
actor.encode.remote(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request))
|
||||
finally:
|
||||
# Put the actor back in the queue.
|
||||
# This is done in a finally block to ensure that the actor is
|
||||
# always put back in the queue, even if an exception/cancellation
|
||||
# is raised.
|
||||
self._idle_actors.put_nowait(actor)
|
||||
return ret
|
||||
|
||||
async def encode_async(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
"""Encode a prompt using the tokenizer group.
|
||||
|
||||
We pick an idle actor and use it to encode the prompt.
|
||||
If there are no idle actors, we wait until one becomes
|
||||
available.
|
||||
The actor is then put back in the queue for future use.
|
||||
This is non-blocking.
|
||||
"""
|
||||
self._ensure_queue_initialized()
|
||||
assert self._idle_actors is not None
|
||||
|
||||
actor = await self._idle_actors.get()
|
||||
try:
|
||||
ret = await actor.encode.remote(request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
finally:
|
||||
# Put the actor back in the queue.
|
||||
# This is done in a finally block to ensure that the actor is
|
||||
# always put back in the queue, even if an exception/cancellation
|
||||
# is raised.
|
||||
self._idle_actors.put_nowait(actor)
|
||||
return ret
|
||||
|
||||
def get_max_input_len(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
"""Get the maximum input length for the LoRA request."""
|
||||
return self._local_tokenizer_group.get_max_input_len(lora_request)
|
||||
|
||||
def get_lora_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
return self._local_tokenizer_group.get_lora_tokenizer(lora_request)
|
||||
|
||||
async def get_lora_tokenizer_async(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
return await self._local_tokenizer_group.get_lora_tokenizer_async(
|
||||
lora_request)
|
||||
|
||||
|
||||
def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
|
||||
"""Copy over all current process environment variables to the runtime_env.
|
||||
|
||||
The variables in runtime_env will take precedence over the current process
|
||||
environment variables.
|
||||
|
||||
runtime_env will be modified in place."""
|
||||
env_vars = os.environ.copy()
|
||||
runtime_env.setdefault("env_vars", {})
|
||||
env_vars.update(runtime_env["env_vars"])
|
||||
runtime_env["env_vars"] = env_vars
|
||||
78
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
Normal file
78
vllm/transformers_utils/tokenizer_group/tokenizer_group.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
|
||||
get_lora_tokenizer_async,
|
||||
get_tokenizer)
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
from vllm.utils import LRUCache
|
||||
|
||||
|
||||
class TokenizerGroup(BaseTokenizerGroup):
|
||||
"""A group of tokenizers that can be used for LoRA adapters."""
|
||||
|
||||
def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
|
||||
max_input_length: Optional[int], **tokenizer_config):
|
||||
self.tokenizer_id = tokenizer_id
|
||||
self.tokenizer_config = tokenizer_config
|
||||
self.enable_lora = enable_lora
|
||||
self.max_input_length = max_input_length
|
||||
self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
|
||||
self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
|
||||
capacity=max_num_seqs) if enable_lora else None
|
||||
|
||||
def ping(self) -> bool:
|
||||
"""Check if the tokenizer group is alive."""
|
||||
return True
|
||||
|
||||
def get_max_input_len(self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> Optional[int]:
|
||||
"""Get the maximum input length for the LoRA request."""
|
||||
return self.max_input_length
|
||||
|
||||
def encode(self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
tokenizer = self.get_lora_tokenizer(lora_request)
|
||||
return tokenizer.encode(prompt)
|
||||
|
||||
async def encode_async(
|
||||
self,
|
||||
prompt: str,
|
||||
request_id: Optional[str] = None,
|
||||
lora_request: Optional[LoRARequest] = None) -> List[int]:
|
||||
tokenizer = await self.get_lora_tokenizer_async(lora_request)
|
||||
return tokenizer.encode(prompt)
|
||||
|
||||
def get_lora_tokenizer(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
if not lora_request or not self.enable_lora:
|
||||
return self.tokenizer
|
||||
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||
tokenizer = (get_lora_tokenizer(
|
||||
lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||
return tokenizer
|
||||
else:
|
||||
return self.lora_tokenizers.get(lora_request.lora_int_id)
|
||||
|
||||
async def get_lora_tokenizer_async(
|
||||
self,
|
||||
lora_request: Optional[LoRARequest] = None
|
||||
) -> "PreTrainedTokenizer":
|
||||
if not lora_request or not self.enable_lora:
|
||||
return self.tokenizer
|
||||
if lora_request.lora_int_id not in self.lora_tokenizers:
|
||||
tokenizer = (await get_lora_tokenizer_async(
|
||||
lora_request, **self.tokenizer_config) or self.tokenizer)
|
||||
self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
|
||||
return tokenizer
|
||||
else:
|
||||
return self.lora_tokenizers.get(lora_request.lora_int_id)
|
||||
Reference in New Issue
Block a user