79 lines
3.2 KiB
Python
79 lines
3.2 KiB
Python
from typing import List, Optional
|
|
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
from vllm.lora.request import LoRARequest
|
|
from vllm.transformers_utils.tokenizer import (get_lora_tokenizer,
|
|
get_lora_tokenizer_async,
|
|
get_tokenizer)
|
|
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
|
BaseTokenizerGroup)
|
|
from vllm.utils import LRUCache
|
|
|
|
|
|
class TokenizerGroup(BaseTokenizerGroup):
    """A group of tokenizers that can be used for LoRA adapters.

    Holds one base tokenizer, loaded eagerly from ``tokenizer_id``, plus
    (only when LoRA is enabled) a cache of per-adapter tokenizers keyed by
    ``LoRARequest.lora_int_id`` and loaded lazily on first use.
    """

    def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 max_input_length: Optional[int], **tokenizer_config):
        """Create the group and eagerly load the base tokenizer.

        Args:
            tokenizer_id: Name or path passed to ``get_tokenizer`` to load
                the base tokenizer.
            enable_lora: Whether per-adapter tokenizers may be used. When
                False, every request is served by the base tokenizer and no
                adapter cache is allocated.
            max_num_seqs: Capacity of the per-adapter tokenizer cache.
            max_input_length: Value reported by ``get_max_input_len``; may be
                None.
            **tokenizer_config: Extra keyword arguments forwarded to every
                tokenizer load (base and per-adapter).
        """
        self.tokenizer_id = tokenizer_id
        self.tokenizer_config = tokenizer_config
        self.enable_lora = enable_lora
        self.max_input_length = max_input_length
        self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config)
        # Allocated only when LoRA is enabled. The LoRA lookups below return
        # early when ``enable_lora`` is False, so they never touch a None cache.
        self.lora_tokenizers = LRUCache[PreTrainedTokenizer](
            capacity=max_num_seqs) if enable_lora else None

    def ping(self) -> bool:
        """Check if the tokenizer group is alive (always True in-process)."""
        return True

    def get_max_input_len(self,
                          lora_request: Optional[LoRARequest] = None
                          ) -> Optional[int]:
        """Get the maximum input length for the LoRA request.

        The same configured ``max_input_length`` is returned for every
        request; ``lora_request`` is accepted but not consulted here.
        """
        return self.max_input_length

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
               lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Tokenize ``prompt`` with the tokenizer chosen for ``lora_request``.

        ``request_id`` is not used by this implementation.
        NOTE(review): presumably kept for interface parity with other
        ``BaseTokenizerGroup`` implementations -- confirm.
        """
        tokenizer = self.get_lora_tokenizer(lora_request)
        return tokenizer.encode(prompt)

    async def encode_async(
            self,
            prompt: str,
            request_id: Optional[str] = None,
            lora_request: Optional[LoRARequest] = None) -> List[int]:
        """Async variant of ``encode``; a missing adapter tokenizer is
        loaded through the async loader."""
        tokenizer = await self.get_lora_tokenizer_async(lora_request)
        return tokenizer.encode(prompt)

    def get_lora_tokenizer(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Return the tokenizer to use for ``lora_request``.

        Returns the base tokenizer when there is no request or LoRA is
        disabled. Otherwise returns the cached tokenizer for the adapter's
        ``lora_int_id``, loading and caching it on a miss.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_request.lora_int_id)
        return self._cache_lora_tokenizer(
            lora_request,
            get_lora_tokenizer(lora_request, **self.tokenizer_config))

    async def get_lora_tokenizer_async(
            self,
            lora_request: Optional[LoRARequest] = None
    ) -> "PreTrainedTokenizer":
        """Async variant of ``get_lora_tokenizer``.

        NOTE: the ``await`` below can yield between the membership check and
        the cache ``put``, so concurrent first requests for the same adapter
        may each load it; every caller still receives a tokenizer and the
        last ``put`` wins.
        """
        if not lora_request or not self.enable_lora:
            return self.tokenizer
        if lora_request.lora_int_id in self.lora_tokenizers:
            return self.lora_tokenizers.get(lora_request.lora_int_id)
        return self._cache_lora_tokenizer(
            lora_request, await
            get_lora_tokenizer_async(lora_request, **self.tokenizer_config))

    def _cache_lora_tokenizer(
            self, lora_request: LoRARequest,
            loaded: Optional["PreTrainedTokenizer"]) -> "PreTrainedTokenizer":
        """Cache and return the tokenizer to use for ``lora_request``.

        ``loaded`` is the result of trying to load the adapter's own
        tokenizer. When it is None, the base tokenizer is cached in its place
        so later requests for this adapter hit the cache instead of retrying
        the load.
        """
        # Explicit ``is None`` instead of ``loaded or self.tokenizer``:
        # Hugging Face tokenizers define ``__len__``, so truthiness is not a
        # reliable "was anything loaded" test.
        # NOTE(review): assumes the loaders signal "no adapter tokenizer" with
        # None -- confirm in vllm.transformers_utils.tokenizer.
        tokenizer = self.tokenizer if loaded is None else loaded
        self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
        return tokenizer