init
266
vllm/_C.py
Normal file
@@ -0,0 +1,266 @@
from typing import Optional

import torch
import torch.nn.functional as F

import ixformer
import ixformer.functions as ixf_F
from ixformer._C import ReduceOp
from ixformer._C import _distributed as cdist
from ixformer._C._distributed import is_initialized, get_default_comm_group
from ixformer.contrib.torch.extension import ixformer_torch as ixft
from ixformer.contrib.torch.data_type_mapping import torch_to_ixformer_dtype


class ops:

    # activations
    @staticmethod
    def silu_and_mul(output, x):
        ixf_F.silu_and_mul(x, output)

    @staticmethod
    def gelu_and_mul(output, x):
        ixf_F.gelu_and_mul(x, output)

    @staticmethod
    def gelu_new(output, x):
        # `approximate` is keyword-only in torch.nn.functional.gelu, so
        # F.gelu(x, "tanh") would raise a TypeError. Write the result into
        # `output`, matching the out-parameter convention of the other
        # kernels in this shim.
        output.copy_(F.gelu(x, approximate="tanh"))
        return output

    @staticmethod
    def gelu_fast(output, x):
        output.copy_(F.gelu(x, approximate="tanh"))
        return output

    # rms norm
    @staticmethod
    def rms_norm(output, x, weight, epsilon):
        ixf_F.rms_norm(x, weight, output, epsilon)

    @staticmethod
    def fused_add_rms_norm(input, residual, weight, epsilon, scale):
        ixf_F.fused_add_rms_norm(input, residual, weight, epsilon, scale)

    # rotary embedding
    @staticmethod
    def rotary_embedding(positions, query, key, head_size,
                         cos_sin_cache, is_neox_style):
        ixf_F.vllm_rotary_embedding_neox(positions, query, key, head_size,
                                         cos_sin_cache, is_neox_style)

    # paged attention
    @staticmethod
    def paged_attention_v1(
        output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,
    ):
        return ixf_F.vllm_single_query_cached_kv_attention(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
        )

    @staticmethod
    def paged_attention_v2(
        output,
        exp_sums,
        max_logits,
        tmp_output,
        query,
        key_cache,
        value_cache,
        head_mapping,
        scale,
        block_tables,
        context_lens,
        block_size,
        max_context_len,
        alibi_slopes=None,
        kv_cache_dtype=None,
        use_sqrt_alibi=False,
    ):
        # 256 is the partition size passed to the v2 kernel.
        return ixf_F.vllm_single_query_cached_kv_attention_v2(
            output,
            256,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,
            scale,
            block_tables,
            context_lens,
            block_size,
            max_context_len,
            alibi_slopes,
            use_sqrt_alibi,
        )

    # awq
    @staticmethod
    def awq_gemm(x, qweight, scales, qzeros, pack_factor):
        return ixf_F.quantized_linear(x, qweight, scales, "awq",
                                      32 // pack_factor, qzeros, None,
                                      group_size=128)

    @staticmethod
    def awq_dequantize(qweight, scales, qzeros, holder1, holder2, holder3):
        raise NotImplementedError()

    # gptq
    @staticmethod
    def gptq_shuffle(qweights, g_idx, weight_bits):
        return ixf_F.vllm_gptq_shuffle(qweights, g_idx)

    @staticmethod
    def gptq_gemm(x, qweight, qzeros, scales, idx, status, weight_bits):
        # For small batches, use the fused quantized kernel; otherwise
        # dequantize the weight once and fall back to a dense matmul.
        batch = x.shape[0]
        if batch <= 8:
            return ixf_F.quantized_linear(x, qweight, scales, "gptq", 4,
                                          qzeros, None, group_size=128)
        o_dtype_str = "fp16" if x.dtype == torch.half else "bf16"
        deq_w = ixf_F.quantized_weight_dequant(qweight, scales, "gptq",
                                               o_dtype_str, 4, qzeros,
                                               group_size=128)
        return torch.matmul(x, deq_w)

    # squeezellm
    @staticmethod
    def squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table):
        raise NotImplementedError()

    # marlin
    @staticmethod
    def marlin_gemm(x_2d, qweight, scales, workspace, size_m, size_n, size_k):
        raise NotImplementedError()

    # moe
    @staticmethod
    def moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                             expert_ids, num_tokens_post_pad):
        raise NotImplementedError()

    # smoothquant
    @staticmethod
    def quant(output, input, scale):
        ixf_F.vllm_smooth_quant(output, input, scale)
        return output

    @staticmethod
    def dequant(output, x, scale, global_scale):
        ixf_F.vllm_smooth_dequant(output, x, scale, global_scale)
        return output

    @staticmethod
    def dequant_add_residual(output, x, residual, scale, global_scale):
        if isinstance(x, torch.Tensor):
            ixf_F.vllm_smooth_dequant_add_residual(output, x, residual,
                                                   scale, global_scale)
        return output

    @staticmethod
    def dequant_silu_and_mul_quant(output, x, gate_scale, up_scale, scale,
                                   temp=None):
        ixf_F.vllm_smooth_dequant_silu_and_mul_quant(output, x, gate_scale,
                                                     up_scale, scale, temp)

    @staticmethod
    def rms_norm_quant(output, input, weight, epsilon):
        return ixf_F.vllm_smooth_rms_norm_quant(output, input, weight,
                                                epsilon)

    @staticmethod
    def fused_add_rms_norm_quant(output, input, residual, weight, epsilon):
        ixf_F.vllm_smooth_fused_add_rms_norm_quant(output, input, residual,
                                                   weight, epsilon)

    @staticmethod
    def dequant_fused_add_rms_norm_quant(output, input, residual, weight,
                                         epsilon, scale, global_scale):
        ixf_F.vllm_smooth_dequant_fused_add_rms_norm_quant(
            output, input, residual, weight, epsilon, scale, global_scale)

    @staticmethod
    def dequant_rotary_embedding(positions, query, key, head_size,
                                 cos_sin_cache, query_out, key_out,
                                 query_scale, key_scale, is_neox_style):
        ixf_F.vllm_smooth_dequant_rotary_embedding_neox(
            positions, query, key, head_size, cos_sin_cache, query_out,
            key_out, query_scale, key_scale, is_neox_style)

    @staticmethod
    def linear_a8_w8_o32_(x, weight, output):
        return ixf_F.linear_i8w8o32(x, weight, output)


class cache_ops:

    @staticmethod
    def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping):
        ixf_F.vllm_cache_ops_reshape_and_cache(
            key, value, key_cache, value_cache, slot_mapping)

    @staticmethod
    def copy_blocks(key_caches, value_caches, block_mapping):
        ixf_F.vllm_copy_cache(key_caches, value_caches, block_mapping)

    @staticmethod
    def swap_blocks(src_key_cache, dst_key_cache, src_to_dst):
        ixf_F.vllm_swap_blocks(src_key_cache, dst_key_cache, src_to_dst)


class custom_ar:

    IS_INIT: bool = False

    @staticmethod
    def is_init():
        # Report whether init has already happened, then mark it as done.
        return_status = custom_ar.IS_INIT
        custom_ar.IS_INIT = True
        return return_status

    @staticmethod
    def init_custom_ar():
        if not is_initialized(get_default_comm_group()):
            group = ixft.create_ixformer_group_from_pg()
            ixformer.cuda.set_device(torch.cuda.current_device())
            cdist.update_default_comm_group(group)
            cdist.ipc.init_communicator_by_nccl()

    @staticmethod
    def all_reduce_reg(ptr, tensor, out=None):
        raise NotImplementedError()

    @staticmethod
    def all_reduce_unreg(ptr, tensor, buffer, out=None):
        dtype = tensor.dtype
        if torch.is_tensor(tensor):
            dtype = torch_to_ixformer_dtype(dtype)

        if out is None:
            out = tensor
        cdist.ipc.allreduce(
            tensor.data_ptr(), out.data_ptr(), dtype, tensor.numel(),
            ReduceOp.SUM)
        return out

    @staticmethod
    def dispose():
        ixformer.distributed.destroy_process_group()

    @staticmethod
    def should_custom_ar(tensor: torch.Tensor, max_size, world_size,
                         full_nvlink):
        return cdist.ipc.should_custom_ar(tensor.numel(),
                                          tensor.element_size(), max_size,
                                          world_size)


class cuda_utils:

    @staticmethod
    def get_max_shared_memory_per_block_device_attribute(gpu):
        # Hard-coded stand-in for the CUDA device attribute query.
        return 100000000
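As a quick illustration of the out-parameter convention these shims follow, here is a minimal sketch, assuming a CUDA device with ixformer installed; the shapes follow vLLM's silu_and_mul layout, where the input packs the gate and up projections along the last dimension:

import torch
from vllm._C import ops

num_tokens, d = 4, 11008
x = torch.randn(num_tokens, 2 * d, dtype=torch.float16, device="cuda")
out = torch.empty(num_tokens, d, dtype=torch.float16, device="cuda")
# The kernel writes silu(x[..., :d]) * x[..., d:] into `out` in place.
ops.silu_and_mul(out, x)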
29
vllm/__init__.py
Normal file
@@ -0,0 +1,29 @@
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
import os

# By default, disable the UMD mempool to avoid memory fragmentation.
if os.getenv("UMD_ENABLEMEMPOOL") is None:
    os.environ["UMD_ENABLEMEMPOOL"] = "0"
os.environ["NCCL_FORCESYNC_DISABLE"] = "1"

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster
from vllm.entrypoints.llm import LLM
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.3.3"

__all__ = [
    "LLM",
    "SamplingParams",
    "RequestOutput",
    "CompletionOutput",
    "LLMEngine",
    "EngineArgs",
    "AsyncLLMEngine",
    "AsyncEngineArgs",
    "initialize_cluster",
]
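The exports above give the standard vLLM 0.3.3 entry point. A minimal offline-inference sketch (the model name is a placeholder; any Hugging Face model path works):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.8, max_tokens=16)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)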
BIN
vllm/__pycache__/_C.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/_moe_C.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/block.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/config.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/logger.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/outputs.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/prefix.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sampling_params.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/sequence.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/test_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
5
vllm/_moe_C.py
Normal file
@@ -0,0 +1,5 @@
import torch
import ixformer.functions as ixf_F


def topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output):
    raise NotImplementedError()
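Since the kernel is stubbed out, here is a pure-PyTorch sketch of the routing semantics it is presumably meant to implement. This is an assumption based on vLLM's MoE top-k gating; the stub's exact output layout is not confirmed by this file, and topk_softmax_reference is a hypothetical helper, not the ixformer kernel:

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Softmax over the expert dimension, then keep the top-k weights and
    # expert ids per token (assumed semantics).
    weights = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = weights.topk(topk, dim=-1)
    return topk_weights, topk_ids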
72
vllm/block.py
Normal file
@@ -0,0 +1,72 @@
"""Token blocks."""
from typing import List

from vllm.utils import Device

_BLANK_TOKEN_ID = -1


class LogicalTokenBlock:
    """A block that stores a contiguous chunk of tokens from left to right.

    Logical blocks are used to represent the states of the corresponding
    physical blocks in the KV cache.
    """

    def __init__(
        self,
        block_number: int,
        block_size: int,
    ) -> None:
        self.block_number = block_number
        self.block_size = block_size

        self.token_ids = [_BLANK_TOKEN_ID] * block_size
        self.num_tokens = 0

    def is_empty(self) -> bool:
        return self.num_tokens == 0

    def get_num_empty_slots(self) -> int:
        return self.block_size - self.num_tokens

    def is_full(self) -> bool:
        return self.num_tokens == self.block_size

    def append_tokens(self, token_ids: List[int]) -> None:
        assert len(token_ids) <= self.get_num_empty_slots()
        curr_idx = self.num_tokens
        self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids
        self.num_tokens += len(token_ids)

    def get_token_ids(self) -> List[int]:
        return self.token_ids[:self.num_tokens]

    def get_last_token_id(self) -> int:
        assert self.num_tokens > 0
        return self.token_ids[self.num_tokens - 1]


class PhysicalTokenBlock:
    """Represents the state of a block in the KV cache."""

    def __init__(
        self,
        device: Device,
        block_number: int,
        block_size: int,
    ) -> None:
        self.device = device
        self.block_number = block_number
        self.block_size = block_size

        self.ref_count = 0

    def __repr__(self) -> str:
        return (f'PhysicalTokenBlock(device={self.device}, '
                f'block_number={self.block_number}, '
                f'ref_count={self.ref_count})')


# Mapping: logical block number -> physical block.
BlockTable = List[PhysicalTokenBlock]
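A short usage sketch of the logical-block bookkeeping above (plain Python, no GPU required):

from vllm.block import LogicalTokenBlock

block = LogicalTokenBlock(block_number=0, block_size=4)
block.append_tokens([101, 102])            # two slots used
assert block.get_num_empty_slots() == 2
block.append_tokens([103, 104])            # block is now full
assert block.is_full()
assert block.get_last_token_id() == 104
assert block.get_token_ids() == [101, 102, 103, 104]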
689
vllm/config.py
Normal file
@@ -0,0 +1,689 @@
from typing import Optional, Union, ClassVar
from dataclasses import dataclass
import os
from packaging.version import Version

import torch
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version

logger = init_logger(__name__)

_GB = 1 << 30


class ModelConfig:
    """Configuration for the model.

    Args:
        model: Name or path of the huggingface model to use.
        tokenizer: Name or path of the huggingface tokenizer to use.
        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
            available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        download_dir: Directory to download and load the weights; defaults to
            the default cache directory of huggingface.
        load_format: The format of the model weights to load:
            "auto" will try to load the weights in the safetensors format and
                fall back to the pytorch bin format if the safetensors format
                is not available.
            "pt" will load the weights in the pytorch bin format.
            "safetensors" will load the weights in the safetensors format.
            "npcache" will load the weights in pytorch format and store
                a numpy cache to speed up the loading.
            "dummy" will initialize the weights with random values, which is
                mainly for profiling.
        dtype: Data type for model weights and activations. The "auto" option
            will use FP16 precision for FP32 and FP16 models, and BF16
            precision for BF16 models.
        seed: Random seed for reproducibility.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id. If unspecified, will use the default
            version.
        code_revision: The specific revision to use for the model code on
            Hugging Face Hub. It can be a branch name, a tag name, or a
            commit id. If unspecified, will use the default version.
        tokenizer_revision: The specific tokenizer version to use. It can be
            a branch name, a tag name, or a commit id. If unspecified, will
            use the default version.
        max_model_len: Maximum length of a sequence (including prompt and
            output). If None, will be derived from the model.
        quantization: Quantization method that was used to quantize the model
            weights. If None, we assume the model weights are not quantized.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA
            graphs. When a sequence has context length larger than this, we
            fall back to eager mode.
    """

    def __init__(
        self,
        model: str,
        tokenizer: str,
        tokenizer_mode: str,
        trust_remote_code: bool,
        download_dir: Optional[str],
        load_format: str,
        dtype: Union[str, torch.dtype],
        seed: int,
        revision: Optional[str] = None,
        code_revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        max_model_len: Optional[int] = None,
        quantization: Optional[str] = None,
        enforce_eager: bool = False,
        max_context_len_to_capture: Optional[int] = None,
    ) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.tokenizer_mode = tokenizer_mode
        self.trust_remote_code = trust_remote_code
        self.download_dir = download_dir
        self.load_format = load_format
        self.seed = seed
        self.revision = revision
        self.code_revision = code_revision
        self.tokenizer_revision = tokenizer_revision
        self.quantization = quantization
        # TODO: align with upstream. CUDA graph capture currently causes a
        # runtime error on this backend, so eager mode is always enforced.
        self.enforce_eager = True
        self.max_context_len_to_capture = max_context_len_to_capture

        if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true":
            # Download the model from the ModelScope hub.
            # Lazy import so that modelscope is not required for normal use.
            from modelscope.hub.snapshot_download import snapshot_download  # pylint: disable=C
            if not os.path.exists(model):
                model_path = snapshot_download(model_id=model,
                                               cache_dir=download_dir,
                                               revision=revision)
            else:
                model_path = model
            self.model = model_path
            self.download_dir = model_path
            self.tokenizer = model_path

        self.hf_config = get_config(self.model, trust_remote_code, revision,
                                    code_revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
        self.max_model_len = _get_and_verify_max_len(self.hf_config,
                                                     max_model_len)
        self._verify_load_format()
        self._verify_tokenizer_mode()
        self._verify_quantization()
        self._verify_cuda_graph()

    def _verify_load_format(self) -> None:
        load_format = self.load_format.lower()
        supported_load_format = [
            "auto", "pt", "safetensors", "npcache", "dummy"
        ]
        rocm_not_supported_load_format = []
        if load_format not in supported_load_format:
            raise ValueError(
                f"Unknown load format: {self.load_format}. Must be one of "
                "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
        if is_hip() and load_format in rocm_not_supported_load_format:
            rocm_supported_load_format = [
                f for f in supported_load_format
                if (f not in rocm_not_supported_load_format)
            ]
            raise ValueError(
                f"load format '{load_format}' is not supported in ROCm. "
                f"Supported load formats are "
                f"{rocm_supported_load_format}")

        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format == "pt":
            raise ValueError(
                "Currently, the 'pt' format is not supported for Mixtral. "
                "Please use the 'safetensors' format instead.")
        self.load_format = load_format

    def _verify_tokenizer_mode(self) -> None:
        tokenizer_mode = self.tokenizer_mode.lower()
        if tokenizer_mode not in ["auto", "slow"]:
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
                "either 'auto' or 'slow'.")
        self.tokenizer_mode = tokenizer_mode

    def _verify_quantization(self) -> None:
        supported_quantization = [
            "awq", "gptq", "squeezellm", "marlin", "smoothquant"
        ]
        rocm_not_supported_quantization = ["awq", "marlin"]
        if self.quantization is not None:
            self.quantization = self.quantization.lower()

        # Parse quantization method from the HF model config, if available.
        hf_quant_config = getattr(self.hf_config, "quantization_config", None)
        if hf_quant_config is not None:
            hf_quant_method = str(hf_quant_config["quant_method"]).lower()
            # If the GPTQ model is serialized in marlin format, use marlin.
            if (hf_quant_method == "gptq"
                    and "is_marlin_format" in hf_quant_config
                    and hf_quant_config["is_marlin_format"]):
                hf_quant_method = "marlin"
            if self.quantization is None:
                self.quantization = hf_quant_method
            elif self.quantization != hf_quant_method:
                raise ValueError(
                    "Quantization method specified in the model config "
                    f"({hf_quant_method}) does not match the quantization "
                    f"method specified in the `quantization` argument "
                    f"({self.quantization}).")

        if self.quantization is not None:
            if self.quantization not in supported_quantization:
                raise ValueError(
                    f"Unknown quantization method: {self.quantization}. Must "
                    f"be one of {supported_quantization}.")
            if is_hip(
            ) and self.quantization in rocm_not_supported_quantization:
                raise ValueError(
                    f"{self.quantization} quantization is currently not "
                    f"supported in ROCm.")
            if self.quantization != "marlin":
                logger.warning(
                    f"{self.quantization} quantization is not fully "
                    "optimized yet. The speed can be slower than "
                    "non-quantized models.")

    def _verify_cuda_graph(self) -> None:
        if self.max_context_len_to_capture is None:
            self.max_context_len_to_capture = self.max_model_len
        self.max_context_len_to_capture = min(self.max_context_len_to_capture,
                                              self.max_model_len)

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_num_attention_heads = self.hf_config.num_attention_heads
        tensor_parallel_size = parallel_config.tensor_parallel_size
        if total_num_attention_heads % tensor_parallel_size != 0:
            raise ValueError(
                f"Total number of attention heads ({total_num_attention_heads})"
                " must be divisible by tensor parallel size "
                f"({tensor_parallel_size}).")

        total_num_hidden_layers = self.hf_config.num_hidden_layers
        pipeline_parallel_size = parallel_config.pipeline_parallel_size
        if total_num_hidden_layers % pipeline_parallel_size != 0:
            raise ValueError(
                f"Total number of hidden layers ({total_num_hidden_layers}) "
                "must be divisible by pipeline parallel size "
                f"({pipeline_parallel_size}).")

    def get_sliding_window(self) -> Optional[int]:
        return getattr(self.hf_config, "sliding_window", None)

    def get_vocab_size(self) -> int:
        return self.hf_config.vocab_size

    def get_hidden_size(self) -> int:
        return self.hf_config.hidden_size

    def get_head_size(self) -> int:
        if hasattr(self.hf_config, "head_dim"):
            return self.hf_config.head_dim
        # FIXME(woosuk): This may not be true for all models.
        return self.hf_config.hidden_size // self.hf_config.num_attention_heads

    def get_total_num_kv_heads(self) -> int:
        """Returns the total number of KV heads."""
        # For GPTBigCode & Falcon:
        # NOTE: for falcon, when new_decoder_architecture is True, the
        # multi_query flag is ignored and we use n_head_kv for the number of
        # KV heads.
        falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"]
        new_decoder_arch_falcon = (
            self.hf_config.model_type in falcon_model_types
            and getattr(self.hf_config, "new_decoder_architecture", False))
        if not new_decoder_arch_falcon and getattr(self.hf_config,
                                                   "multi_query", False):
            # Multi-query attention, only one KV head.
            # Currently, tensor parallelism is not supported in this case.
            return 1

        attributes = [
            # For Falcon:
            "n_head_kv",
            "num_kv_heads",
            # For LLaMA-2:
            "num_key_value_heads",
            # For ChatGLM:
            "multi_query_group_num",
        ]
        for attr in attributes:
            num_kv_heads = getattr(self.hf_config, attr, None)
            if num_kv_heads is not None:
                return num_kv_heads

        # For non-grouped-query attention models, the number of KV heads is
        # equal to the number of attention heads.
        return self.hf_config.num_attention_heads

    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
        """Returns the number of KV heads per GPU."""
        total_num_kv_heads = self.get_total_num_kv_heads()
        # If tensor parallelism is used, we divide the number of KV heads by
        # the tensor parallel size. We will replicate the KV heads in the
        # case where the number of KV heads is smaller than the tensor
        # parallel size so each GPU has at least one KV head.
        return max(1,
                   total_num_kv_heads // parallel_config.tensor_parallel_size)

    def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
        total_num_hidden_layers = self.hf_config.num_hidden_layers
        return total_num_hidden_layers // parallel_config.pipeline_parallel_size


class CacheConfig:
    """Configuration for the KV cache.

    Args:
        block_size: Size of a cache block in number of tokens.
        gpu_memory_utilization: Fraction of GPU memory to use for the
            vLLM execution.
        swap_space: Size of the CPU swap space per GPU (in GiB).
        cache_dtype: Data type for kv cache storage.
    """

    def __init__(
        self,
        block_size: int,
        gpu_memory_utilization: float,
        swap_space: int,
        cache_dtype: str,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.gpu_memory_utilization = gpu_memory_utilization
        self.swap_space_bytes = swap_space * _GB
        self.cache_dtype = cache_dtype
        self.sliding_window = sliding_window
        self._verify_args()
        self._verify_cache_dtype()

        # Will be set after profiling.
        self.num_gpu_blocks = None
        self.num_cpu_blocks = None

    def metrics_info(self):
        # Convert cache_config to dict(key: str, value: str) for prometheus
        # metrics info.
        return {key: str(value) for key, value in self.__dict__.items()}

    def _verify_args(self) -> None:
        if self.gpu_memory_utilization > 1.0:
            raise ValueError(
                "GPU memory utilization must be less than 1.0. Got "
                f"{self.gpu_memory_utilization}.")

    def _verify_cache_dtype(self) -> None:
        if self.cache_dtype == "auto":
            pass
        elif self.cache_dtype == "fp8_e5m2":
            nvcc_cuda_version = get_nvcc_cuda_version()
            if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"):
                raise ValueError(
                    "FP8 is not supported when cuda version is lower than 11.8."
                )
            device_name = torch.cuda.get_device_name()
            if "AMD" in device_name:
                raise NotImplementedError(
                    "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.")
            logger.info(
                "Using fp8_e5m2 data type to store kv cache. It reduces "
                "the GPU memory footprint and boosts the performance, "
                "but it may cause a slight accuracy drop. "
                "Currently we only support fp8 without scaling factors and "
                "use e5m2 as the default format.")
        else:
            raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        total_cpu_memory = get_cpu_memory()
        # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel
        # group are in the same node. However, the GPUs may span multiple nodes.
        num_gpus_per_node = parallel_config.tensor_parallel_size
        cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node

        msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of "
               f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is "
               "allocated for the swap space.")
        if cpu_memory_usage > 0.7 * total_cpu_memory:
            raise ValueError("Too large swap space. " + msg)
        elif cpu_memory_usage > 0.4 * total_cpu_memory:
            logger.warning("Possibly too large swap space. " + msg)


class ParallelConfig:
    """Configuration for the distributed execution.

    Args:
        pipeline_parallel_size: Number of pipeline parallel groups.
        tensor_parallel_size: Number of tensor parallel groups.
        worker_use_ray: Whether to use Ray for model workers. Will be set to
            True if either pipeline_parallel_size or tensor_parallel_size is
            greater than 1.
        max_parallel_loading_workers: Maximum number of parallel batches when
            loading the model sequentially, to avoid RAM OOM when using
            tensor parallelism with large models.
        disable_custom_all_reduce: Disable the custom all-reduce kernel and
            fall back to NCCL.
    """

    def __init__(
        self,
        pipeline_parallel_size: int,
        tensor_parallel_size: int,
        worker_use_ray: bool,
        max_parallel_loading_workers: Optional[int] = None,
        disable_custom_all_reduce: bool = False,
    ) -> None:
        self.pipeline_parallel_size = pipeline_parallel_size
        if is_neuron():
            # For Neuron device support, assign TP=1 here to avoid sharding
            # within vLLM directly. Transformers-neuronx takes the
            # neuron_tp_degree attribute and distributes the workload to
            # multiple NeuronCores.
            self.tensor_parallel_size = 1
            self.neuron_tp_degree = tensor_parallel_size
        else:
            self.tensor_parallel_size = tensor_parallel_size
        self.worker_use_ray = worker_use_ray
        self.max_parallel_loading_workers = max_parallel_loading_workers
        self.disable_custom_all_reduce = disable_custom_all_reduce

        self.world_size = pipeline_parallel_size * self.tensor_parallel_size
        # Ray worker is not supported for Neuron backend.
        if self.world_size > 1 and not is_neuron():
            self.worker_use_ray = True
        self._verify_args()

    def _verify_args(self) -> None:
        if self.pipeline_parallel_size > 1:
            raise NotImplementedError(
                "Pipeline parallelism is not supported yet.")
        if not self.disable_custom_all_reduce and self.world_size > 1:
            if is_hip():
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported on AMD GPUs.")
            elif self.pipeline_parallel_size > 1:
                self.disable_custom_all_reduce = True
                logger.info(
                    "Disabled the custom all-reduce kernel because it is not "
                    "supported with pipeline parallelism.")

        # FIXME(woosuk): Fix the stability issues and re-enable the custom
        # all-reduce kernel.
        if not self.disable_custom_all_reduce and self.world_size > 1:
            self.disable_custom_all_reduce = True
            logger.info(
                "Custom all-reduce kernels are temporarily disabled due to "
                "stability issues. We will re-enable them once the issues are "
                "resolved.")


class SchedulerConfig:
    """Scheduler configuration.

    Args:
        max_num_batched_tokens: Maximum number of tokens to be processed in
            a single iteration.
        max_num_seqs: Maximum number of sequences to be processed in a single
            iteration.
        max_model_len: Maximum length of a sequence (including prompt
            and generated text).
        max_paddings: Maximum number of paddings to be added to a batch.
    """

    def __init__(
        self,
        max_num_batched_tokens: Optional[int],
        max_num_seqs: int,
        max_model_len: int,
        max_paddings: int,
    ) -> None:
        if max_num_batched_tokens is not None:
            self.max_num_batched_tokens = max_num_batched_tokens
        else:
            # If max_model_len is too short, use 2048 as the default value
            # for higher throughput.
            self.max_num_batched_tokens = max(max_model_len, 2048)
        self.max_num_seqs = max_num_seqs
        self.max_model_len = max_model_len
        self.max_paddings = max_paddings
        self._verify_args()

    def _verify_args(self) -> None:
        if self.max_num_batched_tokens < self.max_model_len:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
                f"smaller than max_model_len ({self.max_model_len}). "
                "This effectively limits the maximum sequence length to "
                "max_num_batched_tokens and makes vLLM reject longer "
                "sequences. Please increase max_num_batched_tokens or "
                "decrease max_model_len.")
        if self.max_num_batched_tokens < self.max_num_seqs:
            raise ValueError(
                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
                "be greater than or equal to max_num_seqs "
                f"({self.max_num_seqs}).")


class DeviceConfig:

    def __init__(self, device: str = "auto") -> None:
        if device == "auto":
            # Automated device type detection.
            if torch.cuda.is_available():
                self.device_type = "cuda"
            elif is_neuron():
                self.device_type = "neuron"
            else:
                raise RuntimeError("No supported device detected.")
        else:
            # Device type is assigned explicitly.
            self.device_type = device

        # Some device types require processing inputs on CPU.
        if self.device_type in ["neuron"]:
            self.device = torch.device("cpu")
        else:
            # Set device with device type.
            self.device = torch.device(self.device_type)

    @property
    def is_neuron(self):
        return self.device_type == "neuron"


@dataclass
class LoRAConfig:
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None
    lora_dtype: Optional[torch.dtype] = None
    lora_extra_vocab_size: int = 256
    # This is a constant.
    lora_vocab_padding_size: ClassVar[int] = 256

    def __post_init__(self):
        # Keep this in sync with csrc/punica/bgmv/bgmv_config.h
        possible_max_ranks = (8, 16, 32, 64)
        possible_lora_extra_vocab_size = (0, 256, 512)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
            raise ValueError(
                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
                f"must be one of {possible_lora_extra_vocab_size}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})")

    def verify_with_model_config(self, model_config: ModelConfig):
        if self.lora_dtype in (None, "auto"):
            self.lora_dtype = model_config.dtype
        elif isinstance(self.lora_dtype, str):
            self.lora_dtype = getattr(torch, self.lora_dtype)
        if model_config.quantization is not None:
            raise ValueError(
                "LoRA is not supported with quantized models yet.")

    def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
        if scheduler_config.max_num_batched_tokens > 65528:
            raise ValueError(
                "Due to limitations of the custom LoRA CUDA kernel, "
                "max_num_batched_tokens must be <= 65528 when "
                "LoRA is enabled.")


_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,
    "float16": torch.float16,
    "float": torch.float32,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]


def _get_and_verify_dtype(
    config: PretrainedConfig,
    dtype: Union[str, torch.dtype],
) -> torch.dtype:
    # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
    # because config.torch_dtype can be None.
    config_dtype = getattr(config, "torch_dtype", None)
    if config_dtype is None:
        config_dtype = torch.float32

    if isinstance(dtype, str):
        dtype = dtype.lower()
        if dtype == "auto":
            if config_dtype == torch.float32:
                # Following common practice, we use float16 for float32
                # models.
                torch_dtype = torch.float16
            else:
                torch_dtype = config_dtype
        else:
            if dtype not in _STR_DTYPE_TO_TORCH_DTYPE:
                raise ValueError(f"Unknown dtype: {dtype}")
            torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype]
    elif isinstance(dtype, torch.dtype):
        torch_dtype = dtype
    else:
        raise ValueError(f"Unknown dtype: {dtype}")

    if is_hip() and torch_dtype == torch.float32:
        rocm_supported_dtypes = [
            k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
            if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
        ]
        raise ValueError(f"dtype '{dtype}' is not supported in ROCm. "
                         f"Supported dtypes are {rocm_supported_dtypes}")

    # Verify the dtype.
    if torch_dtype != config_dtype:
        if torch_dtype == torch.float32:
            # Upcasting to float32 is allowed.
            pass
        elif config_dtype == torch.float32:
            # Downcasting from float32 to float16 or bfloat16 is allowed.
            pass
        else:
            # Casting between float16 and bfloat16 is allowed with a warning.
            logger.warning(f"Casting {config_dtype} to {torch_dtype}.")

    return torch_dtype


def _get_and_verify_max_len(
    hf_config: PretrainedConfig,
    max_model_len: Optional[int],
) -> int:
    """Get and verify the model's maximum length."""
    derived_max_model_len = float("inf")
    possible_keys = [
        # OPT
        "max_position_embeddings",
        # GPT-2
        "n_positions",
        # MPT
        "max_seq_len",
        # ChatGLM2
        "seq_length",
        # Others
        "model_max_length",
        "max_sequence_length",
        "max_seq_length",
        "seq_len",
    ]
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
        if max_len is not None:
            derived_max_model_len = min(derived_max_model_len, max_len)
    if derived_max_model_len == float("inf"):
        if max_model_len is not None:
            # If max_model_len is specified, we use it.
            return max_model_len

        default_max_len = 2048
        logger.warning(
            "The model's config.json does not contain any of the following "
            "keys to determine the original maximum length of the model: "
            f"{possible_keys}. Assuming the model's maximum length is "
            f"{default_max_len}.")
        derived_max_model_len = default_max_len

    rope_scaling = getattr(hf_config, "rope_scaling", None)
    if rope_scaling is not None:
        assert "factor" in rope_scaling
        scaling_factor = rope_scaling["factor"]
        if "type" in rope_scaling:
            rope_type = rope_scaling["type"]
        elif "rope_type" in rope_scaling:
            rope_type = rope_scaling["rope_type"]
        else:
            raise ValueError(
                "rope_scaling must have a 'type' or 'rope_type' key.")

        if rope_type == "yarn":
            derived_max_model_len = rope_scaling[
                "original_max_position_embeddings"]
        derived_max_model_len *= scaling_factor

    if max_model_len is None:
        max_model_len = derived_max_model_len
    elif max_model_len > derived_max_model_len:
        raise ValueError(
            f"User-specified max_model_len ({max_model_len}) is greater than "
            f"the derived max_model_len ({derived_max_model_len}) from the "
            "model's config.json. This may lead to incorrect model outputs "
            "or CUDA errors. Make sure the value is correct and within the "
            "model context size.")
    return int(max_model_len)
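To make the max-length derivation concrete, a small sketch using a stand-in config object (SimpleNamespace here is a hypothetical stand-in for a transformers PretrainedConfig; the attribute names match the possible_keys list above):

from types import SimpleNamespace
from vllm.config import _get_and_verify_max_len

cfg = SimpleNamespace(
    max_position_embeddings=4096,
    rope_scaling={"type": "yarn", "factor": 4.0,
                  "original_max_position_embeddings": 4096},
)
# yarn scaling: 4096 * 4.0, so the derived maximum length is 16384.
assert _get_and_verify_max_len(cfg, max_model_len=None) == 16384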
0
vllm/core/__init__.py
Normal file
BIN
vllm/core/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/core/__pycache__/block_manager.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/core/__pycache__/policy.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/core/__pycache__/scheduler.cpython-310.pyc
Normal file
Binary file not shown.
330
vllm/core/block_manager.py
Normal file
@@ -0,0 +1,330 @@
"""A block manager that manages token blocks."""
import enum
from typing import Dict, List, Optional, Set, Tuple

from vllm.block import BlockTable, PhysicalTokenBlock
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device


class BlockAllocator:
    """Manages free physical token blocks for a device.

    The allocator maintains a list of free blocks and allocates a block when
    requested. When a block is freed, its reference count is decremented. If
    the reference count becomes zero, the block is added back to the free
    list.
    """

    def __init__(
        self,
        device: Device,
        block_size: int,
        num_blocks: int,
    ) -> None:
        self.device = device
        self.block_size = block_size
        self.num_blocks = num_blocks

        # Initialize the free blocks.
        self.free_blocks: BlockTable = []
        for i in range(num_blocks):
            block = PhysicalTokenBlock(device=device,
                                       block_number=i,
                                       block_size=block_size)
            self.free_blocks.append(block)

    def allocate(self) -> PhysicalTokenBlock:
        if not self.free_blocks:
            raise ValueError("Out of memory! No free blocks are available.")
        block = self.free_blocks.pop()
        block.ref_count = 1
        return block

    def free(self, block: PhysicalTokenBlock) -> None:
        if block.ref_count == 0:
            raise ValueError(f"Double free! {block} is already freed.")
        block.ref_count -= 1
        if block.ref_count == 0:
            self.free_blocks.append(block)

    def get_num_free_blocks(self) -> int:
        return len(self.free_blocks)


class AllocStatus(enum.Enum):
    """Result for BlockSpaceManager.can_allocate

    1. Ok: seq_group can be allocated now.
    2. Later: seq_group cannot be allocated now, but the allocator's total
       capacity is larger than what seq_group requires, so the allocation
       may succeed later.
    3. Never: seq_group can never be allocated because it is too large to
       fit in GPU memory.
    """
    OK = enum.auto()
    LATER = enum.auto()
    NEVER = enum.auto()


class BlockSpaceManager:
    """Manages the mapping between logical and physical token blocks."""

    def __init__(
        self,
        block_size: int,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
        watermark: float = 0.01,
        sliding_window: Optional[int] = None,
    ) -> None:
        self.block_size = block_size
        self.num_total_gpu_blocks = num_gpu_blocks
        self.num_total_cpu_blocks = num_cpu_blocks

        self.block_sliding_window = None
        if sliding_window is not None:
            assert sliding_window % block_size == 0, (sliding_window,
                                                      block_size)
            self.block_sliding_window = sliding_window // block_size

        self.watermark = watermark
        assert watermark >= 0.0

        self.watermark_blocks = int(watermark * num_gpu_blocks)
        self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
                                            num_gpu_blocks)
        self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
                                            num_cpu_blocks)
        # Mapping: seq_id -> BlockTable.
        self.block_tables: Dict[int, BlockTable] = {}

    def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
        # FIXME(woosuk): Here we assume that all sequences in the group share
        # the same prompt. This may not be true for preempted sequences.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
        num_required_blocks = len(seq.logical_token_blocks)

        if seq_group.prefix is not None and seq_group.prefix.allocated:
            num_required_blocks -= seq_group.prefix.get_num_blocks()

        if self.block_sliding_window is not None:
            num_required_blocks = min(num_required_blocks,
                                      self.block_sliding_window)
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

        # Use watermark to avoid frequent cache eviction.
        if (self.num_total_gpu_blocks - num_required_blocks <
                self.watermark_blocks):
            return AllocStatus.NEVER
        if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
            return AllocStatus.OK
        else:
            return AllocStatus.LATER

    def allocate(self, seq_group: SequenceGroup) -> None:
        # NOTE: Here we assume that all sequences in the group have the same
        # prompt.
        seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]

        # Allocate new physical token blocks that will store the prompt
        # tokens.
        num_prompt_blocks = len(seq.logical_token_blocks)

        block_table: BlockTable = []
        prefix_block_table: BlockTable = []
        num_prefix_blocks = 0

        prefix = seq_group.prefix
        if prefix is not None and prefix.allocated:
            # Prefix has already been allocated. Use the existing block table.
            num_prompt_blocks -= prefix.get_num_blocks()
            for block in prefix.block_table:
                block.ref_count += seq_group.num_seqs()
                block_table.append(block)

        for logical_idx in range(num_prompt_blocks):
            if (self.block_sliding_window is not None
                    and logical_idx >= self.block_sliding_window):
                block = block_table[logical_idx % self.block_sliding_window]
            else:
                block = self.gpu_allocator.allocate()
                # Set the reference counts of the token blocks.
                block.ref_count = seq_group.num_seqs()
            block_table.append(block)

        if prefix is not None and not prefix.allocated:
            # Allocate blocks for the prefix; we will compute the prefix's
            # KV cache in this run.
            num_prefix_blocks = prefix.get_num_blocks()
            prefix_block_table = block_table[:num_prefix_blocks]
            for block in prefix_block_table:
                block.ref_count += 1
            prefix.set_block_table(prefix_block_table)

        # Assign the block table for each sequence.
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()

    def can_append_slot(self, seq_group: SequenceGroup) -> bool:
        # Simple heuristic: If there is at least one free block
        # for each sequence, we can append.
        num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()
        num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING)
        return num_seqs <= num_free_gpu_blocks

    def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
        """Allocate a physical slot for a new token."""
        logical_blocks = seq.logical_token_blocks
        block_table = self.block_tables[seq.seq_id]

        if len(block_table) < len(logical_blocks):
            if (self.block_sliding_window
                    and len(block_table) >= self.block_sliding_window):
                # Reuse a block.
                block_table.append(block_table[len(block_table) %
                                               self.block_sliding_window])
            else:
                # The sequence has a new logical block.
                # Allocate a new physical block.
                block = self.gpu_allocator.allocate()
                block_table.append(block)
                return None

        # We want to append the token to the last physical block.
        last_block = block_table[-1]
        assert last_block.device == Device.GPU
        if last_block.ref_count == 1:
            # Not shared with other sequences. Appendable.
            return None
        else:
            # The last block is shared with other sequences.
            # Copy on Write: Allocate a new block and copy the tokens.
            new_block = self.gpu_allocator.allocate()
            block_table[-1] = new_block
            self.gpu_allocator.free(last_block)
            return last_block.block_number, new_block.block_number

    def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        # NOTE: fork does not allocate a new physical block.
        # Thus, it is always safe from OOM.
        src_block_table = self.block_tables[parent_seq.seq_id]
        self.block_tables[child_seq.seq_id] = src_block_table.copy()
        for block in src_block_table:
            block.ref_count += 1

    def _get_physical_blocks(
            self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
        # NOTE: Here, we assume that the physical blocks are only shared by
        # the sequences in the same group.
        blocks: Set[PhysicalTokenBlock] = set()
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                continue
            blocks.update(self.block_tables[seq.seq_id])
        return list(blocks)

    def can_swap_in(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED)
        num_free_blocks = self.gpu_allocator.get_num_free_blocks()
        # NOTE: Conservatively, we assume that every sequence will allocate
        # at least one free block right after the swap-in.
        # NOTE: This should match the logic in can_append_slot().
        num_required_blocks = len(blocks) + num_swapped_seqs
        return num_free_blocks - num_required_blocks >= self.watermark_blocks

    def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # CPU block -> GPU block.
        if seq_group.prefix is not None:
            # Make sure to swap in the prefix first.
            assert seq_group.prefix.allocated and seq_group.prefix.computed

        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]
            if seq_group.prefix is not None:
                for block in seq_group.prefix.block_table:
                    new_block_table.append(block)
                    block.ref_count += 1

            for cpu_block in block_table:
                if cpu_block in mapping:
                    gpu_block = mapping[cpu_block]
                    gpu_block.ref_count += 1
                else:
                    gpu_block = self.gpu_allocator.allocate()
                    mapping[cpu_block] = gpu_block
                new_block_table.append(gpu_block)
                # Free the CPU block swapped in to GPU.
                self.cpu_allocator.free(cpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            cpu_block.block_number: gpu_block.block_number
            for cpu_block, gpu_block in mapping.items()
        }
        return block_number_mapping

    def can_swap_out(self, seq_group: SequenceGroup) -> bool:
        blocks = self._get_physical_blocks(seq_group)
        return len(blocks) <= self.cpu_allocator.get_num_free_blocks()

    def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
        # GPU block -> CPU block.
        mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            new_block_table: BlockTable = []
            block_table = self.block_tables[seq.seq_id]

            for gpu_block in block_table:
                if (seq_group.prefix is not None
                        and gpu_block in seq_group.prefix.block_table):
                    # NOTE: We do not swap out the prefix blocks for now.
                    self.gpu_allocator.free(gpu_block)
                    continue

                if gpu_block in mapping:
                    cpu_block = mapping[gpu_block]
                    cpu_block.ref_count += 1
                else:
                    cpu_block = self.cpu_allocator.allocate()
                    mapping[gpu_block] = cpu_block
                new_block_table.append(cpu_block)
                # Free the GPU block swapped out to CPU.
                self.gpu_allocator.free(gpu_block)
            self.block_tables[seq.seq_id] = new_block_table

        block_number_mapping = {
            gpu_block.block_number: cpu_block.block_number
            for gpu_block, cpu_block in mapping.items()
        }
        return block_number_mapping

    def _free_block_table(self, block_table: BlockTable) -> None:
        for block in set(block_table):
            if block.device == Device.GPU:
                self.gpu_allocator.free(block)
            else:
                self.cpu_allocator.free(block)

    def free(self, seq: Sequence) -> None:
        if seq.seq_id not in self.block_tables:
            # Already freed or hasn't been scheduled yet.
            return
        block_table = self.block_tables[seq.seq_id]
        self._free_block_table(block_table)
        del self.block_tables[seq.seq_id]

    def reset(self) -> None:
        for block_table in self.block_tables.values():
            self._free_block_table(block_table)
        self.block_tables.clear()

    def get_block_table(self, seq: Sequence) -> List[int]:
        block_table = self.block_tables[seq.seq_id]
        return [block.block_number for block in block_table]

    def get_num_free_gpu_blocks(self) -> int:
        return self.gpu_allocator.get_num_free_blocks()

    def get_num_free_cpu_blocks(self) -> int:
        return self.cpu_allocator.get_num_free_blocks()
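A minimal sketch of the reference-counting contract of BlockAllocator (using the CPU device so no GPU is needed; the manual ref_count bump stands in for what fork() does when a child sequence shares a block):

from vllm.core.block_manager import BlockAllocator
from vllm.utils import Device

allocator = BlockAllocator(Device.CPU, block_size=16, num_blocks=2)
block = allocator.allocate()      # ref_count is set to 1
block.ref_count += 1              # a second sequence now shares the block
allocator.free(block)             # 2 -> 1: still in use, not reclaimed
allocator.free(block)             # 1 -> 0: returned to the free list
assert allocator.get_num_free_blocks() == 2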
47
vllm/core/policy.py
Normal file
@@ -0,0 +1,47 @@
from collections import deque
from typing import Deque

from vllm.sequence import SequenceGroup


class Policy:

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        raise NotImplementedError

    def sort_by_priority(
        self,
        now: float,
        seq_groups: Deque[SequenceGroup],
    ) -> Deque[SequenceGroup]:
        return deque(
            sorted(
                seq_groups,
                key=lambda seq_group: self.get_priority(now, seq_group),
                reverse=True,
            ))


class FCFS(Policy):

    def get_priority(
        self,
        now: float,
        seq_group: SequenceGroup,
    ) -> float:
        return now - seq_group.metrics.arrival_time


class PolicyFactory:

    _POLICY_REGISTRY = {
        'fcfs': FCFS,
    }

    @classmethod
    def get_policy(cls, policy_name: str, **kwargs) -> Policy:
        return cls._POLICY_REGISTRY[policy_name](**kwargs)
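A short sketch of the FCFS ordering, using SimpleNamespace stand-ins for SequenceGroup (get_priority only touches metrics.arrival_time, so a real SequenceGroup is not required; fake_group is a hypothetical helper for illustration only):

from collections import deque
from types import SimpleNamespace
from vllm.core.policy import PolicyFactory

def fake_group(arrival_time, request_id):
    return SimpleNamespace(
        metrics=SimpleNamespace(arrival_time=arrival_time),
        request_id=request_id)

groups = deque([fake_group(10.0, "late"), fake_group(5.0, "early")])
policy = PolicyFactory.get_policy(policy_name="fcfs")
ordered = policy.sort_by_priority(now=12.0, seq_groups=groups)
# Longer waiting time means higher priority, so the earliest arrival wins.
assert [g.request_id for g in ordered] == ["early", "late"]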
498
vllm/core/scheduler.py
Normal file
@@ -0,0 +1,498 @@
from collections import deque
import enum
import time
from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set

from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.block_manager import AllocStatus, BlockSpaceManager
from vllm.core.policy import PolicyFactory
from vllm.lora.request import LoRARequest
from vllm.logger import init_logger
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
                           SequenceGroupMetadata, SequenceStatus)
from vllm.prefix import PrefixPool

logger = init_logger(__name__)


class PreemptionMode(enum.Enum):
    """Preemption modes.

    1. Swapping: Swap out the blocks of the preempted sequences to CPU memory
    and swap them back in when the sequences are resumed.
    2. Recomputation: Discard the blocks of the preempted sequences and
    recompute them when the sequences are resumed, treating the sequences as
    new prompts.
    """
    SWAP = enum.auto()
    RECOMPUTE = enum.auto()


class SchedulerOutputs:

    def __init__(
        self,
        scheduled_seq_groups: Iterable[SequenceGroup],
        prompt_run: bool,
        num_batched_tokens: int,
        blocks_to_swap_in: Dict[int, int],
        blocks_to_swap_out: Dict[int, int],
        blocks_to_copy: Dict[int, List[int]],
        ignored_seq_groups: List[SequenceGroup],
    ) -> None:
        self.scheduled_seq_groups = scheduled_seq_groups
        self.prompt_run = prompt_run
        self.num_batched_tokens = num_batched_tokens
        self.blocks_to_swap_in = blocks_to_swap_in
        self.blocks_to_swap_out = blocks_to_swap_out
        self.blocks_to_copy = blocks_to_copy
        # Swap in and swap out should never happen at the same time.
        assert not (blocks_to_swap_in and blocks_to_swap_out)
        self.ignored_seq_groups = ignored_seq_groups

        self.num_loras = len(self.lora_requests)
        if self.num_loras > 0:
            self._sort_by_lora_ids()

    def is_empty(self) -> bool:
        # NOTE: We do not consider the ignored sequence groups.
        return (not self.scheduled_seq_groups and not self.blocks_to_swap_in
                and not self.blocks_to_swap_out and not self.blocks_to_copy)

    def _sort_by_lora_ids(self) -> None:
        self.scheduled_seq_groups = sorted(
            self.scheduled_seq_groups,
            key=lambda g: (g.lora_request.lora_int_id
                           if g.lora_request else 0, g.request_id))

    @property
    def lora_requests(self) -> Set[LoRARequest]:
        return {g.lora_request for g in self.scheduled_seq_groups}


class Scheduler:

    def __init__(
        self,
        scheduler_config: SchedulerConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        # Note for LoRA scheduling: the current policy is extremely
        # simple and NOT fair. It can lead to starvation of some
        # LoRAs. This should be improved in the future.
        self.lora_config = lora_config

        self.prompt_limit = min(self.scheduler_config.max_model_len,
                                self.scheduler_config.max_num_batched_tokens)

        # Instantiate the scheduling policy.
        self.policy = PolicyFactory.get_policy(policy_name="fcfs")
        # Create the block space manager.
        self.block_manager = BlockSpaceManager(
            block_size=self.cache_config.block_size,
            num_gpu_blocks=self.cache_config.num_gpu_blocks,
            num_cpu_blocks=self.cache_config.num_cpu_blocks,
            sliding_window=self.cache_config.sliding_window)

        # Create the prefix pool to cache the prefixes.
        self.prefix_pool = PrefixPool(self.cache_config.block_size)

        # Sequence groups in the WAITING state.
        self.waiting: Deque[SequenceGroup] = deque()
        # Sequence groups in the RUNNING state.
        self.running: Deque[SequenceGroup] = deque()
        # Sequence groups in the SWAPPED state.
        self.swapped: Deque[SequenceGroup] = deque()

    @property
    def lora_enabled(self) -> bool:
        return bool(self.lora_config)

    def add_seq_group(self, seq_group: SequenceGroup) -> None:
        # Add sequence groups to the waiting queue.
        self.waiting.append(seq_group)

    def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None:
        """Aborts a sequence group with the given ID.

        Check if the sequence group with the given ID
        is present in any of the state queues.
        If present, remove the sequence group from the state queue.
        Also, if any of the sequences in the sequence group is not finished,
        free the sequence with status `FINISHED_ABORTED`.
        Otherwise, do nothing.

        Args:
            request_id: The ID(s) of the sequence group to abort.
        """
        if isinstance(request_id, str):
            request_id = (request_id, )
        request_ids = set(request_id)
        for state_queue in [self.waiting, self.running, self.swapped]:
            aborted_groups: List[SequenceGroup] = []
            for seq_group in state_queue:
                if not request_ids:
                    # Using 'break' here may add two extra iterations,
                    # but is acceptable to reduce complexity.
                    break
                if seq_group.request_id in request_ids:
                    # Append the aborted group to the pending list.
                    aborted_groups.append(seq_group)
                    request_ids.remove(seq_group.request_id)
            for aborted_group in aborted_groups:
                # Remove the sequence group from the state queue.
                state_queue.remove(aborted_group)
                for seq in aborted_group.get_seqs():
                    if seq.is_finished():
                        continue
                    seq.status = SequenceStatus.FINISHED_ABORTED
                    self.free_seq(seq)

    def has_unfinished_seqs(self) -> bool:
        return bool(self.waiting or self.running or self.swapped)

    def get_num_unfinished_seq_groups(self) -> int:
        return len(self.waiting) + len(self.running) + len(self.swapped)

    def _schedule(self) -> SchedulerOutputs:
        # Blocks that need to be swapped or copied before model execution.
        blocks_to_swap_in: Dict[int, int] = {}
        blocks_to_swap_out: Dict[int, int] = {}
        blocks_to_copy: Dict[int, List[int]] = {}

        # Fix the current time.
        now = time.monotonic()

        # Join waiting sequences if possible.
        if not self.swapped:
            ignored_seq_groups: List[SequenceGroup] = []
            scheduled: List[SequenceGroup] = []
            # The total number of sequences on the fly, including the
            # requests in the generation phase.
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None
            seq_lens: List[int] = []

            # Optimization: We do not sort the waiting queue since the
            # preempted sequence groups are added to the front and the new
            # sequence groups are added to the back.
            leftover_waiting_sequences = deque()
            while self.waiting:
                seq_group = self.waiting[0]
                waiting_seqs = seq_group.get_seqs(
                    status=SequenceStatus.WAITING)
                assert len(waiting_seqs) == 1, (
                    "Waiting sequence group should have only one prompt "
                    "sequence.")
                num_prompt_tokens = waiting_seqs[0].get_len()
                if num_prompt_tokens > self.prompt_limit:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds limit of {self.prompt_limit}")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                # If the sequence group cannot be allocated, stop.
                can_allocate = self.block_manager.can_allocate(seq_group)
                if can_allocate == AllocStatus.LATER:
                    break
                elif can_allocate == AllocStatus.NEVER:
                    logger.warning(
                        f"Input prompt ({num_prompt_tokens} tokens) is too long"
                        f" and exceeds the capacity of block_manager")
                    for seq in waiting_seqs:
                        seq.status = SequenceStatus.FINISHED_IGNORED
                    ignored_seq_groups.append(seq_group)
                    self.waiting.popleft()
                    continue

                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
                            curr_loras) >= self.lora_config.max_loras:
                        # We don't have space for another LoRA, so
                        # we ignore this request for now.
                        leftover_waiting_sequences.appendleft(seq_group)
                        self.waiting.popleft()
                        continue

                # If the number of batched tokens exceeds the limit, stop.
                new_seq_lens = seq_lens + [num_prompt_tokens]
                num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
                if (num_batched_tokens >
                        self.scheduler_config.max_num_batched_tokens):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                num_paddings = num_batched_tokens - sum(new_seq_lens)
                if num_paddings > self.scheduler_config.max_paddings:
                    break
                seq_lens = new_seq_lens

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.waiting.popleft()
                self._allocate(seq_group)
                self.running.append(seq_group)
                num_curr_seqs += num_new_seqs
                scheduled.append(seq_group)

            self.waiting.extendleft(leftover_waiting_sequences)

            if scheduled or ignored_seq_groups:
                scheduler_outputs = SchedulerOutputs(
                    scheduled_seq_groups=scheduled,
                    prompt_run=True,
                    num_batched_tokens=len(seq_lens) *
                    max(seq_lens) if seq_lens else 0,
                    blocks_to_swap_in=blocks_to_swap_in,
                    blocks_to_swap_out=blocks_to_swap_out,
                    blocks_to_copy=blocks_to_copy,
                    ignored_seq_groups=ignored_seq_groups,
                )
                return scheduler_outputs

        # NOTE(woosuk): Preemption happens only when there is no available slot
        # to keep all the sequence groups in the RUNNING state.
        # In this case, the policy is responsible for deciding which sequence
        # groups to preempt.
        self.running = self.policy.sort_by_priority(now, self.running)

        # Reserve new token slots for the running sequence groups.
        running: Deque[SequenceGroup] = deque()
        preempted: List[SequenceGroup] = []
        while self.running:
            seq_group = self.running.popleft()
            while not self.block_manager.can_append_slot(seq_group):
                if self.running:
                    # Preempt the lowest-priority sequence groups.
                    victim_seq_group = self.running.pop()
                    self._preempt(victim_seq_group, blocks_to_swap_out)
                    preempted.append(victim_seq_group)
                else:
                    # No other sequence groups can be preempted.
                    # Preempt the current sequence group.
                    self._preempt(seq_group, blocks_to_swap_out)
                    preempted.append(seq_group)
                    break
            else:
                # Append new slots to the sequence group.
                self._append_slot(seq_group, blocks_to_copy)
                running.append(seq_group)
        self.running = running

        # Swap in the sequence groups in the SWAPPED state if possible.
        self.swapped = self.policy.sort_by_priority(now, self.swapped)
        if not preempted:
            num_curr_seqs = sum(seq_group.get_max_num_running_seqs()
                                for seq_group in self.running)
            curr_loras = set(
                seq_group.lora_int_id
                for seq_group in self.running) if self.lora_enabled else None

            leftover_swapped = deque()

            while self.swapped:
                seq_group = self.swapped[0]
                lora_int_id = 0
                if self.lora_enabled:
                    lora_int_id = seq_group.lora_int_id
                    if lora_int_id > 0 and lora_int_id not in curr_loras and len(
                            curr_loras) >= self.lora_config.max_loras:
                        # We don't have space for another LoRA, so
                        # we ignore this request for now.
                        leftover_swapped.appendleft(seq_group)
                        self.swapped.popleft()
                        continue

                # If the sequence group cannot be swapped in, stop.
                if not self.block_manager.can_swap_in(seq_group):
                    break

                # The total number of sequences in the RUNNING state should not
                # exceed the maximum number of sequences.
                num_new_seqs = seq_group.get_max_num_running_seqs()
                if (num_curr_seqs + num_new_seqs >
                        self.scheduler_config.max_num_seqs):
                    break

                if lora_int_id > 0:
                    curr_loras.add(lora_int_id)
                self.swapped.popleft()
                self._swap_in(seq_group, blocks_to_swap_in)
                self._append_slot(seq_group, blocks_to_copy)
                num_curr_seqs += num_new_seqs
                self.running.append(seq_group)

            self.swapped.extendleft(leftover_swapped)

        # Each sequence in the generation phase only takes one token slot.
        # Therefore, the number of batched tokens is equal to the number of
        # sequences in the RUNNING state.
        num_batched_tokens = sum(
            seq_group.num_seqs(status=SequenceStatus.RUNNING)
            for seq_group in self.running)

        scheduler_outputs = SchedulerOutputs(
            scheduled_seq_groups=self.running,
            prompt_run=False,
            num_batched_tokens=num_batched_tokens,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            ignored_seq_groups=[],
        )
        return scheduler_outputs
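Editor's note (not part of the diff): in the prompt phase above, batched prompts are padded to the longest sequence, so the budget check charges len(new_seq_lens) * max(new_seq_lens) tokens and separately caps the wasted padding. A short worked example derived directly from that arithmetic:

# Two prompts are already admitted; a third candidate of length 128 arrives.
seq_lens = [512, 384]
new_seq_lens = seq_lens + [128]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)  # 3 * 512 = 1536
num_paddings = num_batched_tokens - sum(new_seq_lens)       # 1536 - 1024 = 512
# The candidate is admitted only if 1536 <= max_num_batched_tokens and
# 512 <= max_paddings; otherwise the waiting loop breaks.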
    def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]:
        # Schedule sequence groups.
        # This function call changes the internal states of the scheduler
        # such as self.running, self.swapped, and self.waiting.
        scheduler_outputs = self._schedule()
        now = time.time()

        # Create input data structures.
        seq_group_metadata_list: List[SequenceGroupMetadata] = []
        for seq_group in scheduler_outputs.scheduled_seq_groups:
            seq_group.maybe_set_first_scheduled_time(now)

            seq_data: Dict[int, SequenceData] = {}
            block_tables: Dict[int, List[int]] = {}
            for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
                seq_id = seq.seq_id
                seq_data[seq_id] = seq.data
                block_tables[seq_id] = self.block_manager.get_block_table(seq)

            seq_group_metadata = SequenceGroupMetadata(
                request_id=seq_group.request_id,
                is_prompt=scheduler_outputs.prompt_run,
                seq_data=seq_data,
                sampling_params=seq_group.sampling_params,
                block_tables=block_tables,
                lora_request=seq_group.lora_request,
                prefix=seq_group.prefix,
                state=seq_group.state,
            )
            seq_group_metadata_list.append(seq_group_metadata)
        return seq_group_metadata_list, scheduler_outputs

    def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None:
        self.block_manager.fork(parent_seq, child_seq)

    def free_seq(self, seq: Sequence) -> None:
        self.block_manager.free(seq)

    def free_finished_seq_groups(self) -> None:
        self.running = deque(seq_group for seq_group in self.running
                             if not seq_group.is_finished())

    def _allocate(self, seq_group: SequenceGroup) -> None:
        self.block_manager.allocate(seq_group)
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            seq.status = SequenceStatus.RUNNING

    def _append_slot(
        self,
        seq_group: SequenceGroup,
        blocks_to_copy: Dict[int, List[int]],
    ) -> None:
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            ret = self.block_manager.append_slot(seq)
            if ret is not None:
                src_block, dst_block = ret
                if src_block in blocks_to_copy:
                    blocks_to_copy[src_block].append(dst_block)
                else:
                    blocks_to_copy[src_block] = [dst_block]

    def _preempt(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
        preemption_mode: Optional[PreemptionMode] = None,
    ) -> None:
        # If preemption mode is not specified, we determine the mode as
        # follows: We use recomputation by default since it incurs lower
        # overhead than swapping. However, when the sequence group has
        # multiple sequences (e.g., beam search), recomputation is not
        # currently supported. In such a case, we use swapping instead.
        # FIXME(woosuk): This makes our scheduling policy a bit bizarre.
        # As swapped sequences are prioritized over waiting sequences,
        # sequence groups with multiple sequences are implicitly prioritized
        # over sequence groups with a single sequence.
        # TODO(woosuk): Support recomputation for sequence groups with multiple
        # sequences. This may require a more sophisticated CUDA kernel.
        if preemption_mode is None:
            if seq_group.get_max_num_running_seqs() == 1:
                preemption_mode = PreemptionMode.RECOMPUTE
            else:
                preemption_mode = PreemptionMode.SWAP
        if preemption_mode == PreemptionMode.RECOMPUTE:
            self._preempt_by_recompute(seq_group)
        elif preemption_mode == PreemptionMode.SWAP:
            self._preempt_by_swap(seq_group, blocks_to_swap_out)
        else:
            raise AssertionError("Invalid preemption mode.")

    def _preempt_by_recompute(
        self,
        seq_group: SequenceGroup,
    ) -> None:
        seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
        assert len(seqs) == 1
        for seq in seqs:
            seq.status = SequenceStatus.WAITING
            self.block_manager.free(seq)
        # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
        self.waiting.appendleft(seq_group)

    def _preempt_by_swap(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        self._swap_out(seq_group, blocks_to_swap_out)
        self.swapped.append(seq_group)

    def _swap_in(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_in: Dict[int, int],
    ) -> None:
        mapping = self.block_manager.swap_in(seq_group)
        blocks_to_swap_in.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED):
            seq.status = SequenceStatus.RUNNING

    def _swap_out(
        self,
        seq_group: SequenceGroup,
        blocks_to_swap_out: Dict[int, int],
    ) -> None:
        if not self.block_manager.can_swap_out(seq_group):
            # FIXME(woosuk): Abort the sequence group instead of aborting the
            # entire engine.
            raise RuntimeError(
                "Aborted due to the lack of CPU swap space. Please increase "
                "the swap space to avoid this error.")
        mapping = self.block_manager.swap_out(seq_group)
        blocks_to_swap_out.update(mapping)
        for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
            seq.status = SequenceStatus.SWAPPED
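Editor's note (not part of the diff): `_append_slot` accumulates copy-on-write pairs into a dict of lists with an explicit membership check. The same bookkeeping, shown standalone; `dict.setdefault` is an equivalent, slightly more compact idiom:

from typing import Dict, List

blocks_to_copy: Dict[int, List[int]] = {}
for src_block, dst_block in [(3, 7), (3, 9), (5, 8)]:
    blocks_to_copy.setdefault(src_block, []).append(dst_block)
print(blocks_to_copy)  # {3: [7, 9], 5: [8]}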
0
vllm/engine/__init__.py
Normal file
BIN
vllm/engine/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/arg_utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/async_llm_engine.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/llm_engine.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/metrics.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/engine/__pycache__/ray_utils.cpython-310.pyc
Normal file
Binary file not shown.
341
vllm/engine/arg_utils.py
Normal file
@@ -0,0 +1,341 @@
import argparse
import dataclasses
from dataclasses import dataclass
from typing import Optional, Tuple

from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, LoRAConfig)


@dataclass
class EngineArgs:
    """Arguments for vLLM engine."""
    model: str
    tokenizer: Optional[str] = None
    tokenizer_mode: str = 'auto'
    trust_remote_code: bool = False
    download_dir: Optional[str] = None
    load_format: str = 'auto'
    dtype: str = 'auto'
    kv_cache_dtype: str = 'auto'
    seed: int = 0
    max_model_len: Optional[int] = None
    worker_use_ray: bool = False
    pipeline_parallel_size: int = 1
    tensor_parallel_size: int = 1
    max_parallel_loading_workers: Optional[int] = None
    block_size: int = 16
    swap_space: int = 4  # GiB
    gpu_memory_utilization: float = 0.90
    max_num_batched_tokens: Optional[int] = None
    max_num_seqs: int = 256
    max_paddings: int = 256
    disable_log_stats: bool = False
    revision: Optional[str] = None
    code_revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None
    quantization: Optional[str] = None
    enforce_eager: bool = False
    max_context_len_to_capture: int = 8192
    disable_custom_all_reduce: bool = False
    enable_lora: bool = False
    max_loras: int = 1
    max_lora_rank: int = 16
    lora_extra_vocab_size: int = 256
    lora_dtype: str = 'auto'
    max_cpu_loras: Optional[int] = None
    device: str = 'auto'

    def __post_init__(self):
        if self.tokenizer is None:
            self.tokenizer = self.model

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Shared CLI arguments for vLLM engine."""

        # NOTE: If you update any of the arguments below, please also
        # make sure to update docs/source/models/engine_args.rst

        # Model arguments
        parser.add_argument(
            '--model',
            type=str,
            default='facebook/opt-125m',
            help='name or path of the huggingface model to use')
        parser.add_argument(
            '--tokenizer',
            type=str,
            default=EngineArgs.tokenizer,
            help='name or path of the huggingface tokenizer to use')
        parser.add_argument(
            '--revision',
            type=str,
            default=None,
            help='the specific model version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument(
            '--code-revision',
            type=str,
            default=None,
            help='the specific revision to use for the model code on '
            'Hugging Face Hub. It can be a branch name, a tag name, or a '
            'commit id. If unspecified, will use the default version.')
        parser.add_argument(
            '--tokenizer-revision',
            type=str,
            default=None,
            help='the specific tokenizer version to use. It can be a branch '
            'name, a tag name, or a commit id. If unspecified, will use '
            'the default version.')
        parser.add_argument('--tokenizer-mode',
                            type=str,
                            default=EngineArgs.tokenizer_mode,
                            choices=['auto', 'slow'],
                            help='tokenizer mode. "auto" will use the fast '
                            'tokenizer if available, and "slow" will '
                            'always use the slow tokenizer.')
        parser.add_argument('--trust-remote-code',
                            action='store_true',
                            help='trust remote code from huggingface')
        parser.add_argument('--download-dir',
                            type=str,
                            default=EngineArgs.download_dir,
                            help='directory to download and load the weights, '
                            'default to the default cache dir of '
                            'huggingface')
        parser.add_argument(
            '--load-format',
            type=str,
            default=EngineArgs.load_format,
            choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'],
            help='The format of the model weights to load. '
            '"auto" will try to load the weights in the safetensors format '
            'and fall back to the pytorch bin format if safetensors format '
            'is not available. '
            '"pt" will load the weights in the pytorch bin format. '
            '"safetensors" will load the weights in the safetensors format. '
            '"npcache" will load the weights in pytorch format and store '
            'a numpy cache to speed up the loading. '
            '"dummy" will initialize the weights with random values, '
            'which is mainly for profiling.')
        parser.add_argument(
            '--dtype',
            type=str,
            default=EngineArgs.dtype,
            choices=[
                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
            ],
            help='data type for model weights and activations. '
            'The "auto" option will use FP16 precision '
            'for FP32 and FP16 models, and BF16 precision '
            'for BF16 models.')
        parser.add_argument(
            '--kv-cache-dtype',
            type=str,
            choices=['auto', 'fp8_e5m2'],
            default=EngineArgs.kv_cache_dtype,
            help='Data type for kv cache storage. If "auto", will use model '
            'data type. Note FP8 is not supported when cuda version is '
            'lower than 11.8.')
        parser.add_argument('--max-model-len',
                            type=int,
                            default=EngineArgs.max_model_len,
                            help='model context length. If unspecified, '
                            'will be automatically derived from the model.')
        # Parallel arguments
        parser.add_argument('--worker-use-ray',
                            action='store_true',
                            help='use Ray for distributed serving, will be '
                            'automatically set when using more than 1 GPU')
        parser.add_argument('--pipeline-parallel-size',
                            '-pp',
                            type=int,
                            default=EngineArgs.pipeline_parallel_size,
                            help='number of pipeline stages')
        parser.add_argument('--tensor-parallel-size',
                            '-tp',
                            type=int,
                            default=EngineArgs.tensor_parallel_size,
                            help='number of tensor parallel replicas')
        parser.add_argument(
            '--max-parallel-loading-workers',
            type=int,
            default=EngineArgs.max_parallel_loading_workers,
            help='load model sequentially in multiple batches, '
            'to avoid RAM OOM when using tensor '
            'parallel and large models')
        # KV cache arguments
        parser.add_argument('--block-size',
                            type=int,
                            default=EngineArgs.block_size,
                            choices=[16],
                            help='token block size')
        parser.add_argument('--seed',
                            type=int,
                            default=EngineArgs.seed,
                            help='random seed')
        parser.add_argument('--swap-space',
                            type=int,
                            default=EngineArgs.swap_space,
                            help='CPU swap space size (GiB) per GPU')
        parser.add_argument(
            '--gpu-memory-utilization',
            type=float,
            default=EngineArgs.gpu_memory_utilization,
            help='the fraction of GPU memory to be used for '
            'the model executor, which can range from 0 to 1. '
            'If unspecified, will use the default value of 0.9.')
        parser.add_argument('--max-num-batched-tokens',
                            type=int,
                            default=EngineArgs.max_num_batched_tokens,
                            help='maximum number of batched tokens per '
                            'iteration')
        parser.add_argument('--max-num-seqs',
                            type=int,
                            default=EngineArgs.max_num_seqs,
                            help='maximum number of sequences per iteration')
        parser.add_argument('--max-paddings',
                            type=int,
                            default=EngineArgs.max_paddings,
                            help='maximum number of paddings in a batch')
        parser.add_argument('--disable-log-stats',
                            action='store_true',
                            help='disable logging statistics')
        # Quantization settings.
        parser.add_argument('--quantization',
                            '-q',
                            type=str,
                            choices=['awq', 'gptq', 'squeezellm',
                                     'smoothquant', None],
                            default=EngineArgs.quantization,
                            help='Method used to quantize the weights. If '
                            'None, we first check the `quantization_config` '
                            'attribute in the model config file. If that is '
                            'None, we assume the model weights are not '
                            'quantized and use `dtype` to determine the data '
                            'type of the weights.')
        parser.add_argument('--enforce-eager',
                            action='store_true',
                            help='Always use eager-mode PyTorch. If False, '
                            'will use eager mode and CUDA graph in hybrid '
                            'for maximal performance and flexibility.')
        parser.add_argument('--max-context-len-to-capture',
                            type=int,
                            default=EngineArgs.max_context_len_to_capture,
                            help='maximum context length covered by CUDA '
                            'graphs. When a sequence has context length '
                            'larger than this, we fall back to eager mode.')
        parser.add_argument('--disable-custom-all-reduce',
                            action='store_true',
                            default=EngineArgs.disable_custom_all_reduce,
                            help='See ParallelConfig')
        # LoRA related configs
        parser.add_argument('--enable-lora',
                            action='store_true',
                            help='If True, enable handling of LoRA adapters.')
        parser.add_argument('--max-loras',
                            type=int,
                            default=EngineArgs.max_loras,
                            help='Max number of LoRAs in a single batch.')
        parser.add_argument('--max-lora-rank',
                            type=int,
                            default=EngineArgs.max_lora_rank,
                            help='Max LoRA rank.')
        parser.add_argument(
            '--lora-extra-vocab-size',
            type=int,
            default=EngineArgs.lora_extra_vocab_size,
            help=('Maximum size of extra vocabulary that can be '
                  'present in a LoRA adapter (added to the base '
                  'model vocabulary).'))
        parser.add_argument(
            '--lora-dtype',
            type=str,
            default=EngineArgs.lora_dtype,
            choices=['auto', 'float16', 'bfloat16', 'float32'],
            help=('Data type for LoRA. If auto, will default to '
                  'base model dtype.'))
        parser.add_argument(
            '--max-cpu-loras',
            type=int,
            default=EngineArgs.max_cpu_loras,
            help=('Maximum number of LoRAs to store in CPU memory. '
                  'Must be >= max_num_seqs. '
                  'Defaults to max_num_seqs.'))
        parser.add_argument("--device",
                            type=str,
                            default=EngineArgs.device,
                            choices=["auto", "cuda", "neuron"],
                            help='Device type for vLLM execution.')
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs':
        # Get the list of attributes of this dataclass.
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # Set the attributes from the parsed arguments.
        engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
        return engine_args

    def create_engine_configs(
        self,
    ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig,
               DeviceConfig, Optional[LoRAConfig]]:
        device_config = DeviceConfig(self.device)
        model_config = ModelConfig(
            self.model, self.tokenizer, self.tokenizer_mode,
            self.trust_remote_code, self.download_dir, self.load_format,
            self.dtype, self.seed, self.revision, self.code_revision,
            self.tokenizer_revision, self.max_model_len, self.quantization,
            self.enforce_eager, self.max_context_len_to_capture)
        cache_config = CacheConfig(self.block_size,
                                   self.gpu_memory_utilization,
                                   self.swap_space, self.kv_cache_dtype,
                                   model_config.get_sliding_window())
        parallel_config = ParallelConfig(self.pipeline_parallel_size,
                                         self.tensor_parallel_size,
                                         self.worker_use_ray,
                                         self.max_parallel_loading_workers,
                                         self.disable_custom_all_reduce)
        scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
                                           self.max_num_seqs,
                                           model_config.max_model_len,
                                           self.max_paddings)
        lora_config = LoRAConfig(
            max_lora_rank=self.max_lora_rank,
            max_loras=self.max_loras,
            lora_extra_vocab_size=self.lora_extra_vocab_size,
            lora_dtype=self.lora_dtype,
            max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
            and self.max_cpu_loras > 0 else None) if self.enable_lora else None
        return (model_config, cache_config, parallel_config, scheduler_config,
                device_config, lora_config)


@dataclass
class AsyncEngineArgs(EngineArgs):
    """Arguments for asynchronous vLLM engine."""
    engine_use_ray: bool = False
    disable_log_requests: bool = False
    max_log_len: Optional[int] = None

    @staticmethod
    def add_cli_args(
            parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        parser = EngineArgs.add_cli_args(parser)
        parser.add_argument('--engine-use-ray',
                            action='store_true',
                            help='use Ray to start the LLM engine in a '
                            'separate process as the server process.')
        parser.add_argument('--disable-log-requests',
                            action='store_true',
                            help='disable logging requests')
        parser.add_argument('--max-log-len',
                            type=int,
                            default=None,
                            help='max number of prompt characters or prompt '
                            'ID numbers being printed in log. '
                            'Default: unlimited.')
        return parser
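Editor's note (not part of the diff): a minimal sketch of how these argument helpers are meant to be wired into an entrypoint. Only methods defined above are used; the explicit argument list stands in for a real command line, and `create_engine_configs` assumes an environment where the named model can actually be resolved.

import argparse

parser = argparse.ArgumentParser(description='vLLM engine demo')
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args(
    ['--model', 'facebook/opt-125m', '--max-num-seqs', '64'])

engine_args = AsyncEngineArgs.from_cli_args(args)
# Requires a working vLLM install and a resolvable model:
(model_config, cache_config, parallel_config, scheduler_config,
 device_config, lora_config) = engine_args.create_engine_configs()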
689
vllm/engine/async_llm_engine.py
Normal file
@@ -0,0 +1,689 @@
import asyncio
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                    Union, AsyncIterator)

from vllm.lora.request import LoRARequest
from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.ray_utils import initialize_cluster, ray
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams

logger = init_logger(__name__)


class AsyncEngineDeadError(RuntimeError):
    pass


def _raise_exception_on_finish(task: asyncio.Task,
                               request_tracker: "RequestTracker") -> None:
    msg = ("Task finished unexpectedly. This should never happen! "
           "Please open an issue on Github.")
    try:
        try:
            task.result()
        except asyncio.CancelledError:
            return
        except Exception as exc:
            raise AsyncEngineDeadError(
                msg + " See stack trace above for the actual cause.") from exc
        raise AsyncEngineDeadError(msg)
    except Exception as exc:
        request_tracker.propagate_exception(exc)
        raise exc


class AsyncStream:
    """A stream of RequestOutputs for a request that can be
    iterated over asynchronously."""

    def __init__(self, request_id: str) -> None:
        self.request_id = request_id
        self._queue = asyncio.Queue()
        self._finished = False

    def put(self, item: RequestOutput) -> None:
        if self._finished:
            return
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(StopAsyncIteration())
        self._finished = True

    @property
    def finished(self) -> bool:
        return self._finished

    def __aiter__(self):
        return self

    async def __anext__(self) -> RequestOutput:
        result = await self._queue.get()
        if isinstance(result, Exception):
            raise result
        return result


class RequestTracker:
    """Synchronous abstraction for tracking requests."""

    def __init__(self) -> None:
        self._request_streams: Dict[str, AsyncStream] = {}
        self._finished_requests: asyncio.Queue[str] = asyncio.Queue()
        self._new_requests: asyncio.Queue[Tuple[AsyncStream,
                                                dict]] = asyncio.Queue()
        self.new_requests_event = None

    def __contains__(self, item):
        return item in self._request_streams

    def init_event(self):
        self.new_requests_event = asyncio.Event()

    def propagate_exception(self,
                            exc: Exception,
                            request_id: Optional[str] = None) -> None:
        """Propagate an exception to request streams
        (all if request_id is None)."""
        if request_id is not None:
            self._request_streams[request_id].put(exc)
        else:
            for stream in self._request_streams.values():
                stream.put(exc)

    def process_request_output(self,
                               request_output: RequestOutput,
                               *,
                               verbose: bool = False) -> None:
        """Process a request output from the engine."""
        request_id = request_output.request_id

        self._request_streams[request_id].put(request_output)
        if request_output.finished:
            if verbose:
                logger.info(f"Finished request {request_id}.")
            self.abort_request(request_id)

    def add_request(self, request_id: str,
                    **engine_add_request_kwargs) -> AsyncStream:
        """Add a request to be sent to the engine on the next background
        loop iteration."""
        if request_id in self._request_streams:
            raise KeyError(f"Request {request_id} already exists.")

        stream = AsyncStream(request_id)
        self._new_requests.put_nowait((stream, {
            "request_id": request_id,
            **engine_add_request_kwargs
        }))

        self.new_requests_event.set()

        return stream

    def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
        """Abort a request during the next background loop iteration."""
        if verbose:
            logger.info(f"Aborted request {request_id}.")

        self._finished_requests.put_nowait(request_id)

        if request_id not in self._request_streams or self._request_streams[
                request_id].finished:
            # The request has already finished or been aborted.
            return

        self._request_streams[request_id].finish()

    def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
        """Get the new requests and finished requests to be
        sent to the engine."""
        new_requests: List[Dict] = []
        finished_requests: Set[str] = set()

        while not self._finished_requests.empty():
            request_id = self._finished_requests.get_nowait()
            finished_requests.add(request_id)
            self._request_streams.pop(request_id, None)

        while not self._new_requests.empty():
            stream, new_request = self._new_requests.get_nowait()
            if stream.request_id in finished_requests:
                # The request has already been aborted.
                stream.finish()
                continue
            self._request_streams[stream.request_id] = stream
            new_requests.append(new_request)

        self.new_requests_event.clear()

        return new_requests, finished_requests

    async def wait_for_new_requests(self):
        await self.new_requests_event.wait()
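Editor's note (not part of the diff): a small self-contained sketch of the `AsyncStream` protocol defined above. A producer calls `put()` and then `finish()`, which enqueues a `StopAsyncIteration` that cleanly ends the consumer's `async for`; the strings stand in for real `RequestOutput` objects.

import asyncio

async def _demo() -> None:
    stream = AsyncStream(request_id="req-0")
    stream.put("partial output 1")  # stand-in for a RequestOutput
    stream.put("partial output 2")
    stream.finish()                 # enqueues StopAsyncIteration
    async for item in stream:       # ends when StopAsyncIteration is raised
        print(item)

asyncio.run(_demo())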
class _AsyncLLMEngine(LLMEngine):
|
||||
"""Extension of LLMEngine to add async methods."""
|
||||
|
||||
async def step_async(self) -> List[RequestOutput]:
|
||||
"""Performs one decoding iteration and returns newly generated results.
|
||||
The workers are ran asynchronously if possible.
|
||||
|
||||
This function performs one decoding iteration of the engine. It first
|
||||
schedules the sequences to be executed in the next iteration and the
|
||||
token blocks to be swapped in/out/copy. Then, it executes the model
|
||||
and updates the scheduler with the model outputs. Finally, it decodes
|
||||
the sequences and returns the newly generated results.
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
|
||||
# Execute the model.
|
||||
output = (await self._run_workers_async(
|
||||
"execute_model",
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
|
||||
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
|
||||
blocks_to_copy=scheduler_outputs.blocks_to_copy,
|
||||
)) if not scheduler_outputs.is_empty() else []
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
|
||||
# TODO align
|
||||
"""
|
||||
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
|
||||
|
||||
if not scheduler_outputs.is_empty():
|
||||
# Execute the model.
|
||||
all_outputs = await self._run_workers_async(
|
||||
"execute_model",
|
||||
driver_kwargs={
|
||||
"seq_group_metadata_list": seq_group_metadata_list,
|
||||
"blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in,
|
||||
"blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out,
|
||||
"blocks_to_copy": scheduler_outputs.blocks_to_copy,
|
||||
})
|
||||
|
||||
# Only the driver worker returns the sampling results.
|
||||
output = all_outputs[0]
|
||||
else:
|
||||
output = []
|
||||
|
||||
return self._process_model_outputs(output, scheduler_outputs)
|
||||
"""
|
||||
|
||||
async def encode_request_async(
|
||||
self,
|
||||
request_id: str, # pylint: disable=unused-argument
|
||||
prompt: Optional[str],
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
):
|
||||
if prompt_token_ids is None:
|
||||
assert prompt is not None
|
||||
prompt_token_ids = await self.tokenizer.encode_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
lora_request=lora_request)
|
||||
return prompt_token_ids
|
||||
|
||||
async def add_request_async(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> None:
|
||||
if lora_request is not None and not self.lora_config:
|
||||
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
|
||||
"not enabled!")
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
prompt_token_ids = await self.encode_request_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
|
||||
return self.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prefix_pos=prefix_pos,
|
||||
)
|
||||
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
get_all_outputs: bool = False,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers."""
|
||||
coros = []
|
||||
for worker in self.workers:
|
||||
if self.parallel_config.worker_use_ray:
|
||||
coros.append(
|
||||
worker.execute_method.remote(method, *args, **kwargs))
|
||||
else:
|
||||
executor = getattr(worker, method)
|
||||
coros.append(asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(executor, *args, **kwargs)))
|
||||
|
||||
all_outputs = await asyncio.gather(*coros)
|
||||
|
||||
if get_all_outputs:
|
||||
return all_outputs
|
||||
|
||||
# Make sure all workers have the same results.
|
||||
output = all_outputs[0]
|
||||
for other_output in all_outputs[1:]:
|
||||
assert output == other_output
|
||||
return output
|
||||
|
||||
# TODO align
|
||||
"""
|
||||
async def _run_workers_async(
|
||||
self,
|
||||
method: str,
|
||||
*args,
|
||||
driver_args: Optional[List[Any]] = None,
|
||||
driver_kwargs: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
coros = []
|
||||
|
||||
if driver_args is None:
|
||||
driver_args = args
|
||||
if driver_kwargs is None:
|
||||
driver_kwargs = kwargs
|
||||
|
||||
# Run the driver worker asynchronously.
|
||||
driver_executor = getattr(self.driver_worker, method)
|
||||
coros.append(asyncio.get_event_loop().run_in_executor(
|
||||
None, partial(driver_executor, *driver_args, **driver_kwargs)))
|
||||
|
||||
# Run the ray workers asynchronously.
|
||||
for worker in self.workers:
|
||||
coros.append(worker.execute_method.remote(method, *args, **kwargs))
|
||||
|
||||
all_outputs = await asyncio.gather(*coros)
|
||||
return all_outputs
|
||||
"""
|
||||
|
||||
|
||||
class AsyncLLMEngine:
|
||||
"""An asynchronous wrapper for LLMEngine.
|
||||
|
||||
This class is used to wrap the LLMEngine class to make it asynchronous. It
|
||||
uses asyncio to create a background loop that keeps processing incoming
|
||||
requests. The LLMEngine is kicked by the generate method when there
|
||||
are requests in the waiting queue. The generate method yields the outputs
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
NOTE: For the comprehensive list of arguments, see `LLMEngine`.
|
||||
|
||||
Args:
|
||||
worker_use_ray: Whether to use Ray for model workers. Required for
|
||||
distributed execution. Should be the same as
|
||||
`parallel_config.worker_use_ray`.
|
||||
engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
|
||||
async frontend will be executed in a separate process as the
|
||||
model workers.
|
||||
log_requests: Whether to log the requests.
|
||||
max_log_len: Maximum number of prompt characters or prompt ID numbers
|
||||
being printed in log.
|
||||
start_engine_loop: If True, the background task to run the engine
|
||||
will be automatically started in the generate call.
|
||||
*args: Arguments for LLMEngine.
|
||||
*kwargs: Arguments for LLMEngine.
|
||||
"""
|
||||
|
||||
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||
|
||||
def __init__(self,
|
||||
worker_use_ray: bool,
|
||||
engine_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
max_log_len: Optional[int] = None,
|
||||
start_engine_loop: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.engine_use_ray = engine_use_ray
|
||||
self.log_requests = log_requests
|
||||
self.max_log_len = max_log_len
|
||||
self.engine = self._init_engine(*args, **kwargs)
|
||||
|
||||
self.background_loop = None
|
||||
# We need to keep a reference to unshielded
|
||||
# task as well to prevent it from being garbage
|
||||
# collected
|
||||
self._background_loop_unshielded = None
|
||||
self.start_engine_loop = start_engine_loop
|
||||
self._request_tracker = RequestTracker()
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
return (self.background_loop is not None
|
||||
and not self.background_loop.done())
|
||||
|
||||
def get_tokenizer(self):
|
||||
return self.engine.tokenizer.tokenizer
|
||||
|
||||
def start_background_loop(self) -> None:
|
||||
"""Start the background loop."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("Background loop is already running.")
|
||||
self._request_tracker.init_event()
|
||||
|
||||
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||
).create_task(self.run_engine_loop())
|
||||
self._background_loop_unshielded.add_done_callback(
|
||||
partial(_raise_exception_on_finish,
|
||||
request_tracker=self._request_tracker))
|
||||
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||
|
||||
def _init_engine(self, *args,
|
||||
**kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]:
|
||||
if not self.engine_use_ray:
|
||||
engine_class = self._engine_class
|
||||
elif self.worker_use_ray:
|
||||
engine_class = ray.remote(num_cpus=0)(self._engine_class).remote
|
||||
else:
|
||||
# FIXME(woosuk): This is a bit hacky. Be careful when changing the
|
||||
# order of the arguments.
|
||||
cache_config = args[1]
|
||||
parallel_config = args[2]
|
||||
if parallel_config.tensor_parallel_size == 1:
|
||||
num_gpus = cache_config.gpu_memory_utilization
|
||||
else:
|
||||
num_gpus = 1
|
||||
engine_class = ray.remote(num_gpus=num_gpus)(
|
||||
self._engine_class).remote
|
||||
return engine_class(*args, **kwargs)
|
||||
|
||||
async def engine_step(self) -> bool:
|
||||
"""Kick the engine to process the waiting requests.
|
||||
|
||||
Returns True if there are in-progress requests."""
|
||||
|
||||
new_requests, finished_requests = (
|
||||
self._request_tracker.get_new_and_finished_requests())
|
||||
|
||||
for new_request in new_requests:
|
||||
# Add the request into the vLLM engine's waiting queue.
|
||||
# TODO: Maybe add add_request_batch to reduce Ray overhead
|
||||
if self.engine_use_ray:
|
||||
await self.engine.add_request.remote(**new_request)
|
||||
else:
|
||||
await self.engine.add_request_async(**new_request)
|
||||
|
||||
if finished_requests:
|
||||
await self._engine_abort(finished_requests)
|
||||
|
||||
if self.engine_use_ray:
|
||||
request_outputs = await self.engine.step.remote()
|
||||
else:
|
||||
request_outputs = await self.engine.step_async()
|
||||
|
||||
# Put the outputs into the corresponding streams.
|
||||
for request_output in request_outputs:
|
||||
self._request_tracker.process_request_output(
|
||||
request_output, verbose=self.log_requests)
|
||||
|
||||
return len(request_outputs) > 0
|
||||
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
if self.engine_use_ray:
|
||||
await self.engine.abort_request.remote(request_ids)
|
||||
else:
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
# Initialize the RequestTracker here so it uses the right event loop.
|
||||
has_requests_in_progress = False
|
||||
while True:
|
||||
if not has_requests_in_progress:
|
||||
await self._request_tracker.wait_for_new_requests()
|
||||
has_requests_in_progress = await self.engine_step()
|
||||
await asyncio.sleep(0)
|
||||
|
||||
async def add_request(
|
||||
self,
|
||||
request_id: str,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
arrival_time: Optional[float] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> AsyncStream:
|
||||
if self.log_requests:
|
||||
shortened_prompt = prompt
|
||||
shortened_token_ids = prompt_token_ids
|
||||
if self.max_log_len is not None:
|
||||
if shortened_prompt is not None:
|
||||
shortened_prompt = shortened_prompt[:self.max_log_len]
|
||||
if shortened_token_ids is not None:
|
||||
shortened_token_ids = shortened_token_ids[:self.
|
||||
max_log_len]
|
||||
logger.info(f"Received request {request_id}: "
|
||||
f"prompt: {shortened_prompt!r}, "
|
||||
f"prefix_pos: {prefix_pos},"
|
||||
f"sampling_params: {sampling_params}, "
|
||||
f"prompt_token_ids: {shortened_token_ids}, "
|
||||
f"lora_request: {lora_request}.")
|
||||
|
||||
if not self.is_running:
|
||||
if self.start_engine_loop:
|
||||
self.start_background_loop()
|
||||
else:
|
||||
raise AsyncEngineDeadError(
|
||||
"Background loop is not running. If it was running, "
|
||||
"inspect the output to find the stacktrace of the "
|
||||
"error that caused the background loop to stop "
|
||||
"(AsyncEngineDeadError).")
|
||||
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.engine_use_ray:
|
||||
prompt_token_ids = await self.engine.encode_request_async.remote(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
else:
|
||||
prompt_token_ids = await self.engine.encode_request_async(
|
||||
request_id=request_id,
|
||||
prompt=prompt,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
lora_request=lora_request)
|
||||
|
||||
stream = self._request_tracker.add_request(
|
||||
request_id,
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prefix_pos=prefix_pos)
|
||||
|
||||
return stream
|
||||
|
||||
async def generate(
|
||||
self,
|
||||
prompt: Optional[str],
|
||||
sampling_params: SamplingParams,
|
||||
request_id: str,
|
||||
prompt_token_ids: Optional[List[int]] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prefix_pos: Optional[int] = None,
|
||||
) -> AsyncIterator[RequestOutput]:
|
||||
"""Generate outputs for a request.
|
||||
|
||||
Generate outputs for a request. This method is a coroutine. It adds the
|
||||
request into the waiting queue of the LLMEngine and streams the outputs
|
||||
from the LLMEngine to the caller.
|
||||
|
||||
Args:
|
||||
prompt: The prompt string. Can be None if prompt_token_ids is
|
||||
provided.
|
||||
sampling_params: The sampling parameters of the request.
|
||||
request_id: The unique id of the request.
|
||||
prompt_token_ids: The token IDs of the prompt. If None, we
|
||||
use the tokenizer to convert the prompts to token IDs.
|
||||
lora_request: LoRA request to use for generation, if any.
|
||||
prefix_pos: If not None, we use the given position as the prefix
|
||||
position for each prompt. We will cache the prefix's KV
|
||||
cache and reuse it for the next request with the same prefix.
|
||||
This is an experimental feature, and may be replaced with
|
||||
automatic prefix caching in the future.
|
||||
|
||||
Yields:
|
||||
The output `RequestOutput` objects from the LLMEngine for the
|
||||
request.
|
||||
|
||||
Details:
|
||||
- If the engine is not running, start the background loop,
|
||||
which iteratively invokes
|
||||
:meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step`
|
||||
to process the waiting requests.
|
||||
- Add the request to the engine's `RequestTracker`.
|
||||
On the next background loop, this request will be sent to
|
||||
the underlying engine.
|
||||
Also, a corresponding `AsyncStream` will be created.
|
||||
- Wait for the request outputs from `AsyncStream` and yield them.
|
||||
|
||||
Example:
|
||||
>>> # Please refer to entrypoints/api_server.py for
|
||||
>>> # the complete example.
|
||||
>>>
|
||||
>>> # initialize the engine and the example input
|
||||
>>> engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
>>> example_input = {
|
||||
>>> "prompt": "What is LLM?",
|
||||
>>> "stream": False, # assume the non-streaming case
|
||||
>>> "temperature": 0.0,
|
||||
>>> "request_id": 0,
|
||||
>>> }
|
||||
>>>
|
||||
>>> # start the generation
|
||||
>>> results_generator = engine.generate(
|
||||
>>> example_input["prompt"],
|
||||
>>> SamplingParams(temperature=example_input["temperature"]),
|
||||
>>> example_input["request_id"])
|
||||
>>>
|
||||
>>> # get the results
|
||||
>>> final_output = None
|
||||
>>> async for request_output in results_generator:
|
||||
>>> if await request.is_disconnected():
|
||||
>>> # Abort the request if the client disconnects.
|
||||
>>> await engine.abort(request_id)
|
||||
>>> # Return or raise an error
|
||||
>>> ...
|
||||
>>> final_output = request_output
|
||||
>>>
|
||||
>>> # Process and return the final output
|
||||
>>> ...
|
||||
"""
|
||||
# Preprocess the request.
|
||||
# This should not be used for logging, as it is monotonic time.
|
||||
arrival_time = time.monotonic()
|
||||
|
||||
try:
|
||||
stream = await self.add_request(
|
||||
request_id,
|
||||
prompt,
|
||||
sampling_params,
|
||||
prompt_token_ids=prompt_token_ids,
|
||||
arrival_time=arrival_time,
|
||||
lora_request=lora_request,
|
||||
prefix_pos=prefix_pos,
|
||||
)
|
||||
|
||||
async for request_output in stream:
|
||||
yield request_output
|
||||
except (Exception, asyncio.CancelledError) as e:
|
||||
# If there is an exception or coroutine is cancelled, abort the
|
||||
# request.
|
||||
self._abort(request_id)
|
||||
raise e
|
||||

    async def abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        if not self.is_running:
            raise AsyncEngineDeadError(
                "Background loop is not running. If it was running, "
                "inspect the output to find the stacktrace of the "
                "error that caused the background loop to stop "
                "(AsyncEngineDeadError).")

        return self._abort(request_id)

    def _abort(self, request_id: str) -> None:
        """Abort a request.

        Abort a submitted request. If the request is finished or not found,
        this method will be a no-op.

        Args:
            request_id: The unique id of the request.
        """
        self._request_tracker.abort_request(request_id,
                                            verbose=self.log_requests)

    async def get_model_config(self) -> ModelConfig:
        """Get the model configuration of the vLLM engine."""
        if self.engine_use_ray:
            return await self.engine.get_model_config.remote()
        else:
            return self.engine.get_model_config()

    @classmethod
    def from_engine_args(cls,
                         engine_args: AsyncEngineArgs,
                         start_engine_loop: bool = True) -> "AsyncLLMEngine":
        """Creates an async LLM engine from the engine arguments."""
        # Create the engine configs.
        engine_configs = engine_args.create_engine_configs()
        parallel_config = engine_configs[2]
        # Initialize the cluster.
        placement_group = initialize_cluster(parallel_config,
                                             engine_args.engine_use_ray)
        # Create the async LLM engine.
        engine = cls(parallel_config.worker_use_ray,
                     engine_args.engine_use_ray,
                     *engine_configs,
                     placement_group,
                     log_requests=not engine_args.disable_log_requests,
                     log_stats=not engine_args.disable_log_stats,
                     max_log_len=engine_args.max_log_len,
                     start_engine_loop=start_engine_loop)
        return engine
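    # Illustrative sketch (not part of the original file): a typical way to
    # build an engine from CLI-style arguments. `AsyncEngineArgs` accepts the
    # same fields as `EngineArgs`; the model name below is a placeholder.
    #
    # >>> from vllm.engine.arg_utils import AsyncEngineArgs
    # >>> engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    # >>> engine = AsyncLLMEngine.from_engine_args(engine_args)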

    async def do_log_stats(self) -> None:
        if self.engine_use_ray:
            await self.engine.do_log_stats.remote()
        else:
            self.engine.do_log_stats()
1209
vllm/engine/llm_engine.py
Normal file
File diff suppressed because it is too large
225
vllm/engine/metrics.py
Normal file
@@ -0,0 +1,225 @@
from vllm.logger import init_logger
from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics

import time
import numpy as np
from typing import Dict, List
from dataclasses import dataclass

logger = init_logger(__name__)

disable_created_metrics()

# The begin-* and end-* markers here are used by the documentation generator
# to extract the metrics definitions.


# begin-metrics-definitions
class Metrics:

    def __init__(self, labelnames: List[str]):
        # Unregister any existing vLLM collectors
        for collector in list(REGISTRY._collector_to_names):
            if hasattr(collector, "_name") and "vllm" in collector._name:
                REGISTRY.unregister(collector)

        self.info_cache_config = Info(
            name='vllm:cache_config',
            documentation='information of cache_config')

        # System stats
        self.gauge_scheduler_running = Gauge(
            name="vllm:num_requests_running",
            documentation="Number of requests currently running on GPU.",
            labelnames=labelnames)
        self.gauge_scheduler_swapped = Gauge(
            name="vllm:num_requests_swapped",
            documentation="Number of requests swapped to CPU.",
            labelnames=labelnames)
        self.gauge_scheduler_waiting = Gauge(
            name="vllm:num_requests_waiting",
            documentation="Number of requests waiting to be processed.",
            labelnames=labelnames)
        self.gauge_gpu_cache_usage = Gauge(
            name="vllm:gpu_cache_usage_perc",
            documentation="GPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames)
        self.gauge_cpu_cache_usage = Gauge(
            name="vllm:cpu_cache_usage_perc",
            documentation="CPU KV-cache usage. 1 means 100 percent usage.",
            labelnames=labelnames)

        # Raw stats from last model iteration
        self.counter_prompt_tokens = Counter(
            name="vllm:prompt_tokens_total",
            documentation="Number of prefill tokens processed.",
            labelnames=labelnames)
        self.counter_generation_tokens = Counter(
            name="vllm:generation_tokens_total",
            documentation="Number of generation tokens processed.",
            labelnames=labelnames)
        self.histogram_time_to_first_token = Histogram(
            name="vllm:time_to_first_token_seconds",
            documentation="Histogram of time to first token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
                0.75, 1.0, 2.5, 5.0, 7.5, 10.0
            ])
        self.histogram_time_per_output_token = Histogram(
            name="vllm:time_per_output_token_seconds",
            documentation="Histogram of time per output token in seconds.",
            labelnames=labelnames,
            buckets=[
                0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75,
                1.0, 2.5
            ])
        self.histogram_e2e_request_latency = Histogram(
            name="vllm:e2e_request_latency_seconds",
            documentation="Histogram of end to end request latency in seconds.",
            labelnames=labelnames,
            buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])

        # Legacy metrics
        self.gauge_avg_prompt_throughput = Gauge(
            name="vllm:avg_prompt_throughput_toks_per_s",
            documentation="Average prefill throughput in tokens/s.",
            labelnames=labelnames,
        )
        self.gauge_avg_generation_throughput = Gauge(
            name="vllm:avg_generation_throughput_toks_per_s",
            documentation="Average generation throughput in tokens/s.",
            labelnames=labelnames,
        )


# end-metrics-definitions
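# Illustrative sketch (not part of the original file): the Metrics object is
# normally owned by StatLogger, but it can be exercised directly. The label
# values here are placeholders.
#
# >>> metrics = Metrics(labelnames=["model_name"])
# >>> metrics.gauge_scheduler_running.labels(model_name="demo").set(3)
# >>> metrics.counter_prompt_tokens.labels(model_name="demo").inc(128)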


@dataclass
class Stats:
    """Created by LLMEngine for use by StatLogger."""
    now: float

    # System stats.
    num_running: int
    num_waiting: int
    num_swapped: int
    gpu_cache_usage: float
    cpu_cache_usage: float

    # Raw stats from last model iteration.
    num_prompt_tokens: int
    num_generation_tokens: int
    time_to_first_tokens: List[float]
    time_per_output_tokens: List[float]
    time_e2e_requests: List[float]


class StatLogger:
    """StatLogger is used by LLMEngine to log to Prometheus and stdout."""

    def __init__(self, local_interval: float, labels: Dict[str, str]) -> None:
        # Metadata for logging locally.
        self.last_local_log = time.monotonic()
        self.local_interval = local_interval

        # Tracked stats over current local logging interval.
        self.num_prompt_tokens: List[int] = []
        self.num_generation_tokens: List[int] = []

        # Prometheus metrics
        self.labels = labels
        self.metrics = Metrics(labelnames=list(labels.keys()))

    def info(self, type: str, obj: object) -> None:
        if type == "cache_config":
            self.metrics.info_cache_config.info(obj.metrics_info())

    def _get_throughput(self, tracked_stats: List[int], now: float) -> float:
        return float(np.sum(tracked_stats) / (now - self.last_local_log))
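    # Worked example (illustrative, not part of the original file): with
    # tracked_stats = [100, 200] tokens accumulated since last_local_log, and
    # now - last_local_log = 2.0 seconds, _get_throughput returns
    # 300 / 2.0 = 150.0 tokens/s.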

    def _local_interval_elapsed(self, now: float) -> bool:
        elapsed_time = now - self.last_local_log
        return elapsed_time > self.local_interval

    def _log_prometheus(self, stats: Stats) -> None:
        # Set system stat gauges.
        self.metrics.gauge_scheduler_running.labels(**self.labels).set(
            stats.num_running)
        self.metrics.gauge_scheduler_swapped.labels(**self.labels).set(
            stats.num_swapped)
        self.metrics.gauge_scheduler_waiting.labels(**self.labels).set(
            stats.num_waiting)
        self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set(
            stats.gpu_cache_usage)
        self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set(
            stats.cpu_cache_usage)

        # Add to token counters.
        self.metrics.counter_prompt_tokens.labels(**self.labels).inc(
            stats.num_prompt_tokens)
        self.metrics.counter_generation_tokens.labels(**self.labels).inc(
            stats.num_generation_tokens)

        # Observe request level latencies in histograms.
        for ttft in stats.time_to_first_tokens:
            self.metrics.histogram_time_to_first_token.labels(
                **self.labels).observe(ttft)
        for tpot in stats.time_per_output_tokens:
            self.metrics.histogram_time_per_output_token.labels(
                **self.labels).observe(tpot)
        for e2e in stats.time_e2e_requests:
            self.metrics.histogram_e2e_request_latency.labels(
                **self.labels).observe(e2e)

    def _log_prometheus_interval(self, prompt_throughput: float,
                                 generation_throughput: float) -> None:
        # Logs metrics to Prometheus that are computed every logging interval.
        # Supports the legacy gauge metrics that compute throughput on the
        # vLLM side. Moving forward, we should use counters like
        # counter_prompt_tokens and counter_generation_tokens, which log raw
        # data, and compute summaries with rate() on the Grafana/Prometheus
        # side.
        # See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666
        self.metrics.gauge_avg_prompt_throughput.labels(
            **self.labels).set(prompt_throughput)
        self.metrics.gauge_avg_generation_throughput.labels(
            **self.labels).set(generation_throughput)

    def log(self, stats: Stats) -> None:
        """Called by LLMEngine.
        Logs to Prometheus and tracked stats every iteration.
        Logs to stdout every self.local_interval seconds."""

        # Log to Prometheus.
        self._log_prometheus(stats)

        # Save tracked stats for token counters.
        self.num_prompt_tokens.append(stats.num_prompt_tokens)
        self.num_generation_tokens.append(stats.num_generation_tokens)

        # Log locally every local_interval seconds.
        if self._local_interval_elapsed(stats.now):

            # Compute summary metrics for tracked stats (and log them to
            # Prometheus if applicable).
            prompt_throughput = self._get_throughput(self.num_prompt_tokens,
                                                     now=stats.now)
            generation_throughput = self._get_throughput(
                self.num_generation_tokens, now=stats.now)
            self._log_prometheus_interval(
                prompt_throughput=prompt_throughput,
                generation_throughput=generation_throughput)

            # Log to stdout.
            logger.info(
                f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, "
                f"Avg generation throughput: {generation_throughput:.1f} tokens/s, "
                f"Running: {stats.num_running} reqs, "
                f"Swapped: {stats.num_swapped} reqs, "
                f"Pending: {stats.num_waiting} reqs, "
                f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, "
                f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%")

            # Reset tracked stats for next interval.
            self.num_prompt_tokens = []
            self.num_generation_tokens = []
            self.last_local_log = stats.now
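# Illustrative usage sketch (not part of the original file): LLMEngine drives
# this logger once per scheduler iteration. The Stats values below are
# placeholders.
#
# >>> stat_logger = StatLogger(local_interval=5.0,
# ...                          labels={"model_name": "demo"})
# >>> stats = Stats(now=time.monotonic(), num_running=1, num_waiting=0,
# ...               num_swapped=0, gpu_cache_usage=0.5, cpu_cache_usage=0.0,
# ...               num_prompt_tokens=128, num_generation_tokens=16,
# ...               time_to_first_tokens=[0.05], time_per_output_tokens=[0.02],
# ...               time_e2e_requests=[0.4])
# >>> stat_logger.log(stats)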
157
vllm/engine/ray_utils.py
Normal file
@@ -0,0 +1,157 @@
import pickle

from typing import Optional, List, Tuple, TYPE_CHECKING

from vllm.config import ParallelConfig
from vllm.logger import init_logger
from vllm.utils import is_hip, set_cuda_visible_devices, get_ip

logger = init_logger(__name__)

try:
    import ray

    class RayWorkerVllm:
        """Ray wrapper for vllm.worker.Worker, allowing Worker to be
        lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""

        def __init__(self, init_cached_hf_modules=False) -> None:
            if init_cached_hf_modules:
                from transformers.dynamic_module_utils import init_hf_modules
                init_hf_modules()
            self.worker = None
            # The compiled DAG runs the main execution in a different thread,
            # which calls cuda.set_device. This flag indicates whether
            # set_device has been called on that thread.
            self.compiled_dag_cuda_device_set = False

        def init_worker(self, worker_init_fn):
            self.worker = worker_init_fn()

        def __getattr__(self, name):
            return getattr(self.worker, name)

        def execute_method(self, method, *args, **kwargs):
            executor = getattr(self, method)
            return executor(*args, **kwargs)

        def get_node_ip(self) -> str:
            return get_ip()

        def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
            node_id = ray.get_runtime_context().get_node_id()
            gpu_ids = ray.get_gpu_ids()
            return node_id, gpu_ids

        def set_cuda_visible_devices(self, device_ids) -> None:
            set_cuda_visible_devices(device_ids)

        def execute_model_compiled_dag_remote(self, ignored):
            """Used only when compiled DAG is enabled."""
            import torch
            if not self.compiled_dag_cuda_device_set:
                torch.cuda.set_device(self.worker.device)
                self.compiled_dag_cuda_device_set = True

            output = self.worker.execute_model()
            output = pickle.dumps(output)
            return output

except ImportError as e:
    logger.warning(f"Failed to import Ray with {e!r}. "
                   "For distributed inference, please install Ray with "
                   "`pip install ray`.")
    ray = None
    RayWorkerVllm = None

if TYPE_CHECKING:
    from ray.util.placement_group import PlacementGroup


def initialize_cluster(
    parallel_config: ParallelConfig,
    engine_use_ray: bool = False,
    ray_address: Optional[str] = None,
) -> Optional["PlacementGroup"]:
    """Initialize the distributed cluster, possibly with Ray.

    Args:
        parallel_config: The configurations for parallel execution.
        engine_use_ray: Whether to use Ray for the async engine.
        ray_address: The address of the Ray cluster. If None, uses
            the default Ray cluster address.

    Returns:
        An optional `PlacementGroup`. It includes the specification
        of the resources for each distributed worker. None if Ray is
        not used.
    """
    if parallel_config.worker_use_ray or engine_use_ray:
        if ray is None:
            raise ImportError(
                "Ray is not installed. Please install Ray to use distributed "
                "serving.")
        import os
        enable_head_ray = os.environ.get("ENABLE_HEAD_RAY", None)
        if enable_head_ray is None:
            # The HIP and CUDA paths are currently identical, so a single
            # ray.init call covers both.
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init()
        # TODO align
        """
        # Connect to a ray cluster.
        if is_hip():
            ray.init(address=ray_address,
                     ignore_reinit_error=True,
                     num_gpus=parallel_config.world_size)
        else:
            ray.init(address=ray_address, ignore_reinit_error=True)
        """

    if not parallel_config.worker_use_ray:
        assert parallel_config.world_size == 1, (
            "Ray is required if parallel_config.world_size > 1.")
        return None

    # Create placement group for worker processes
    current_placement_group = ray.util.get_current_placement_group()
    if current_placement_group:
        # We are in a placement group
        bundles = current_placement_group.bundle_specs
        # Verify that we can use the placement group.
        gpu_bundles = 0
        for bundle in bundles:
            bundle_gpus = bundle.get("GPU", 0)
            if bundle_gpus > 1:
                raise ValueError(
                    "Placement group bundle cannot have more than 1 GPU.")
            if bundle_gpus:
                gpu_bundles += 1
        if parallel_config.world_size > gpu_bundles:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the placement group.")
    else:
        num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0)
        if parallel_config.world_size > num_gpus_in_cluster:
            raise ValueError(
                "The number of required GPUs exceeds the total number of "
                "available GPUs in the cluster.")
        # Create a new placement group
        placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size)
        current_placement_group = ray.util.placement_group(
            placement_group_specs)
        # Wait until PG is ready - this will block until all
        # requested resources are available, and will timeout
        # if they cannot be provisioned.
        ray.get(current_placement_group.ready(), timeout=1800)

    return current_placement_group
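# Illustrative usage sketch (not part of the original file): the engine calls
# this helper during startup; `parallel_config` below is assumed to come from
# the engine configs.
#
# >>> placement_group = initialize_cluster(parallel_config,
# ...                                      engine_use_ray=False)
# >>> # placement_group is None for single-GPU runs without Ray.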
0
vllm/entrypoints/__init__.py
Normal file
BIN
vllm/entrypoints/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/__pycache__/llm.cpython-310.pyc
Normal file
105
vllm/entrypoints/api_server.py
Normal file
@@ -0,0 +1,105 @@
"""
NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks.
It is not intended for production use. For production use, we recommend using our OpenAI compatible server.
We are also not going to accept PRs modifying this file; please change `vllm/entrypoints/openai/api_server.py` instead.
"""

import argparse
import json
from typing import AsyncGenerator

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid

TIMEOUT_KEEP_ALIVE = 5  # seconds.
app = FastAPI()
engine = None


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.post("/generate")
async def generate(request: Request) -> Response:
    """Generate completion for the request.

    The request should be a JSON object with the following fields:
    - prompt: the prompt to use for the generation.
    - stream: whether to stream the results or not.
    - other fields: the sampling parameters (See `SamplingParams` for details).
    """
    request_dict = await request.json()
    prompt = request_dict.pop("prompt")
    prefix_pos = request_dict.pop("prefix_pos", None)
    stream = request_dict.pop("stream", False)
    sampling_params = SamplingParams(**request_dict)
    request_id = random_uuid()

    results_generator = engine.generate(prompt,
                                        sampling_params,
                                        request_id,
                                        prefix_pos=prefix_pos)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        async for request_output in results_generator:
            prompt = request_output.prompt
            text_outputs = [
                prompt + output.text for output in request_output.outputs
            ]
            ret = {"text": text_outputs}
            yield (json.dumps(ret) + "\0").encode("utf-8")

    if stream:
        return StreamingResponse(stream_results())

    # Non-streaming case
    final_output = None
    async for request_output in results_generator:
        if await request.is_disconnected():
            # Abort the request if the client disconnects.
            await engine.abort(request_id)
            return Response(status_code=499)
        final_output = request_output

    assert final_output is not None
    prompt = final_output.prompt
    text_outputs = [prompt + output.text for output in final_output.outputs]
    ret = {"text": text_outputs}
    return JSONResponse(ret)
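# Illustrative client sketch (not part of the original file): exercising the
# /generate route above with the `requests` package; host and port are the
# defaults from the __main__ block below.
#
# >>> import requests
# >>> resp = requests.post("http://localhost:8000/generate",
# ...                      json={"prompt": "What is LLM?",
# ...                            "temperature": 0.0})
# >>> print(resp.json()["text"])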


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--ssl-keyfile", type=str, default=None)
    parser.add_argument("--ssl-certfile", type=str, default=None)
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser = AsyncEngineArgs.add_cli_args(parser)
    args = parser.parse_args()

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="debug",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)
220
vllm/entrypoints/llm.py
Normal file
@@ -0,0 +1,220 @@
from typing import List, Optional, Union

from tqdm import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from vllm.lora.request import LoRARequest
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.utils import Counter


class LLM:
    """An LLM for generating texts from given prompts and sampling parameters.

    This class includes a tokenizer, a language model (possibly distributed
    across multiple GPUs), and GPU memory space allocated for intermediate
    states (aka KV cache). Given a batch of prompts and sampling parameters,
    this class generates texts from the model, using an intelligent batching
    mechanism and efficient memory management.

    NOTE: This class is intended to be used for offline inference. For online
    serving, use the `AsyncLLMEngine` class instead.
    NOTE: For the comprehensive list of arguments, see `EngineArgs`.

    Args:
        model: The name or path of a HuggingFace Transformers model.
        tokenizer: The name or path of a HuggingFace Transformers tokenizer.
        tokenizer_mode: The tokenizer mode. "auto" will use the fast tokenizer
            if available, and "slow" will always use the slow tokenizer.
        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
            downloading the model and tokenizer.
        tensor_parallel_size: The number of GPUs to use for distributed
            execution with tensor parallelism.
        dtype: The data type for the model weights and activations. Currently,
            we support `float32`, `float16`, and `bfloat16`. If `auto`, we use
            the `torch_dtype` attribute specified in the model config file.
            However, if the `torch_dtype` in the config is `float32`, we will
            use `float16` instead.
        quantization: The method used to quantize the model weights. Currently,
            we support "awq", "gptq" and "squeezellm". If None, we first check
            the `quantization_config` attribute in the model config file. If
            that is None, we assume the model weights are not quantized and use
            `dtype` to determine the data type of the weights.
        revision: The specific model version to use. It can be a branch name,
            a tag name, or a commit id.
        tokenizer_revision: The specific tokenizer version to use. It can be a
            branch name, a tag name, or a commit id.
        seed: The seed to initialize the random number generator for sampling.
        gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to
            reserve for the model weights, activations, and KV cache. Higher
            values will increase the KV cache size and thus improve the model's
            throughput. However, if the value is too high, it may cause out-of-
            memory (OOM) errors.
        swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
            This can be used for temporarily storing the states of the requests
            when their `best_of` sampling parameters are larger than 1. If all
            requests will have `best_of=1`, you can safely set this to 0.
            Otherwise, too small values may cause out-of-memory (OOM) errors.
        enforce_eager: Whether to enforce eager execution. If True, we will
            disable CUDA graph and always execute the model in eager mode.
            If False, we will use CUDA graph and eager execution in hybrid.
        max_context_len_to_capture: Maximum context len covered by CUDA graphs.
            When a sequence has context length larger than this, we fall back
            to eager mode.
        disable_custom_all_reduce: See ParallelConfig
    """

    def __init__(
        self,
        model: str,
        tokenizer: Optional[str] = None,
        tokenizer_mode: str = "auto",
        trust_remote_code: bool = False,
        tensor_parallel_size: int = 1,
        dtype: str = "auto",
        quantization: Optional[str] = None,
        revision: Optional[str] = None,
        tokenizer_revision: Optional[str] = None,
        seed: int = 0,
        gpu_memory_utilization: float = 0.9,
        swap_space: int = 4,
        enforce_eager: bool = False,
        max_context_len_to_capture: int = 8192,
        disable_custom_all_reduce: bool = False,
        **kwargs,
    ) -> None:
        if "disable_log_stats" not in kwargs:
            kwargs["disable_log_stats"] = True
        engine_args = EngineArgs(
            model=model,
            tokenizer=tokenizer,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size,
            dtype=dtype,
            quantization=quantization,
            revision=revision,
            tokenizer_revision=tokenizer_revision,
            seed=seed,
            gpu_memory_utilization=gpu_memory_utilization,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            max_context_len_to_capture=max_context_len_to_capture,
            disable_custom_all_reduce=disable_custom_all_reduce,
            **kwargs,
        )
        self.llm_engine = LLMEngine.from_engine_args(engine_args)
        self.request_counter = Counter()

    def get_tokenizer(
            self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
        return self.llm_engine.tokenizer.tokenizer

    def set_tokenizer(
        self,
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
    ) -> None:
        self.llm_engine.tokenizer.tokenizer = tokenizer

    def generate(
        self,
        prompts: Optional[Union[str, List[str]]] = None,
        sampling_params: Optional[SamplingParams] = None,
        prompt_token_ids: Optional[List[List[int]]] = None,
        prefix_pos: Optional[Union[int, List[int]]] = None,
        use_tqdm: bool = True,
        lora_request: Optional[LoRARequest] = None,
    ) -> List[RequestOutput]:
        """Generates the completions for the input prompts.

        NOTE: This class automatically batches the given prompts, considering
        the memory constraint. For the best performance, put all of your prompts
        into a single list and pass it to this method.

        Args:
            prompts: A list of prompts to generate completions for.
            sampling_params: The sampling parameters for text generation. If
                None, we use the default sampling parameters.
            prompt_token_ids: A list of token IDs for the prompts. If None, we
                use the tokenizer to convert the prompts to token IDs.
            prefix_pos: If not None, we use the given position as the prefix
                position for each prompt. We will cache the prefix's KV
                cache and reuse it for the next request with the same prefix.
                This is an experimental feature, and may be replaced with
                automatic prefix caching in the future.
            use_tqdm: Whether to use tqdm to display the progress bar.
            lora_request: LoRA request to use for generation, if any.

        Returns:
            A list of `RequestOutput` objects containing the generated
            completions in the same order as the input prompts.
        """
        if prompts is None and prompt_token_ids is None:
            raise ValueError("Either prompts or prompt_token_ids must be "
                             "provided.")
        if isinstance(prompts, str):
            # Convert a single prompt to a list.
            prompts = [prompts]
        if (prompts is not None and prompt_token_ids is not None
                and len(prompts) != len(prompt_token_ids)):
            raise ValueError("The lengths of prompts and prompt_token_ids "
                             "must be the same.")
        if sampling_params is None:
            # Use default sampling params.
            sampling_params = SamplingParams()

        # Add requests to the engine.
        num_requests = len(prompts) if prompts is not None else len(
            prompt_token_ids)
        for i in range(num_requests):
            prompt = prompts[i] if prompts is not None else None
            prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None
            token_ids = None if prompt_token_ids is None else prompt_token_ids[
                i]
            self._add_request(prompt,
                              sampling_params,
                              token_ids,
                              lora_request=lora_request,
                              prefix_pos=prefix_pos_i)
        return self._run_engine(use_tqdm)
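    # Illustrative usage sketch (not part of the original file): batch offline
    # generation with this class. The model name is a placeholder.
    #
    # >>> llm = LLM(model="facebook/opt-125m")
    # >>> outputs = llm.generate(["Hello, my name is",
    # ...                         "The capital of France is"],
    # ...                        SamplingParams(temperature=0.8))
    # >>> for out in outputs:
    # ...     print(out.outputs[0].text)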

    def _add_request(
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        prompt_token_ids: Optional[List[int]],
        lora_request: Optional[LoRARequest] = None,
        prefix_pos: Optional[int] = None,
    ) -> None:
        request_id = str(next(self.request_counter))
        self.llm_engine.add_request(request_id,
                                    prompt,
                                    sampling_params,
                                    prompt_token_ids,
                                    lora_request=lora_request,
                                    prefix_pos=prefix_pos)

    def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]:
        # Initialize tqdm.
        if use_tqdm:
            num_requests = self.llm_engine.get_num_unfinished_requests()
            pbar = tqdm(total=num_requests, desc="Processed prompts")
        # Run the engine.
        outputs: List[RequestOutput] = []
        while self.llm_engine.has_unfinished_requests():
            step_outputs = self.llm_engine.step()
            for output in step_outputs:
                if output.finished:
                    outputs.append(output)
                    if use_tqdm:
                        pbar.update(1)
        if use_tqdm:
            pbar.close()
        # Sort the outputs by request ID.
        # This is necessary because some requests may finish earlier than
        # requests that were submitted before them.
        outputs = sorted(outputs, key=lambda x: int(x.request_id))
        return outputs
0
vllm/entrypoints/openai/__init__.py
Normal file
BIN
vllm/entrypoints/openai/__pycache__/__init__.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/api_server.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/protocol.cpython-310.pyc
Normal file
BIN
vllm/entrypoints/openai/__pycache__/serving_chat.cpython-310.pyc
Normal file
251
vllm/entrypoints/openai/api_server.py
Normal file
@@ -0,0 +1,251 @@
import argparse
import asyncio
import json
from contextlib import asynccontextmanager
import os
import importlib
import inspect

from prometheus_client import make_asgi_app
import fastapi
import uvicorn
from http import HTTPStatus
from fastapi import Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse, Response

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
from vllm.logger import init_logger
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_engine import LoRA

TIMEOUT_KEEP_ALIVE = 5  # seconds

openai_serving_chat: OpenAIServingChat = None
openai_serving_completion: OpenAIServingCompletion = None
logger = init_logger(__name__)


@asynccontextmanager
async def lifespan(app: fastapi.FastAPI):

    async def _force_log():
        while True:
            await asyncio.sleep(10)
            await engine.do_log_stats()

    if not engine_args.disable_log_stats:
        asyncio.create_task(_force_log())

    yield


app = fastapi.FastAPI(lifespan=lifespan)


class LoRAParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        lora_list = []
        for item in values:
            name, path = item.split('=')
            lora_list.append(LoRA(name, path))
        setattr(namespace, self.dest, lora_list)


def parse_args():
    parser = argparse.ArgumentParser(
        description="vLLM OpenAI-Compatible RESTful API server.")
    parser.add_argument("--host", type=str, default=None, help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument("--allow-credentials",
                        action="store_true",
                        help="allow credentials")
    parser.add_argument("--allowed-origins",
                        type=json.loads,
                        default=["*"],
                        help="allowed origins")
    parser.add_argument("--allowed-methods",
                        type=json.loads,
                        default=["*"],
                        help="allowed methods")
    parser.add_argument("--allowed-headers",
                        type=json.loads,
                        default=["*"],
                        help="allowed headers")
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help=
        "If provided, the server will require this key to be presented in the header."
    )
    parser.add_argument("--served-model-name",
                        type=str,
                        default=None,
                        help="The model name used in the API. If not "
                        "specified, the model name will be the same as "
                        "the huggingface name.")
    parser.add_argument(
        "--lora-modules",
        type=str,
        default=None,
        nargs='+',
        action=LoRAParserAction,
        help=
        "LoRA module configurations in the format name=path. Multiple modules can be specified."
    )
    parser.add_argument("--chat-template",
                        type=str,
                        default=None,
                        help="The file path to the chat template, "
                        "or the template in single-line form "
                        "for the specified model")
    parser.add_argument("--response-role",
                        type=str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=true`.")
    parser.add_argument("--ssl-keyfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL key file")
    parser.add_argument("--ssl-certfile",
                        type=str,
                        default=None,
                        help="The file path to the SSL cert file")
    parser.add_argument(
        "--root-path",
        type=str,
        default=None,
        help="FastAPI root_path when app is behind a path based routing proxy")
    parser.add_argument(
        "--middleware",
        type=str,
        action="append",
        default=[],
        help="Additional ASGI middleware to apply to the app. "
        "We accept multiple --middleware arguments. "
        "The value should be an import path. "
        "If a function is provided, vLLM will add it to the server using @app.middleware('http'). "
        "If a class is provided, vLLM will add it to the server using app.add_middleware(). "
    )

    parser = AsyncEngineArgs.add_cli_args(parser)
    return parser.parse_args()


# Add Prometheus ASGI middleware to route /metrics requests
metrics_app = make_asgi_app()
app.mount("/metrics", metrics_app)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
    err = openai_serving_chat.create_error_response(message=str(exc))
    return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)


@app.get("/health")
async def health() -> Response:
    """Health check."""
    return Response(status_code=200)


@app.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
    return JSONResponse(content=models.model_dump())


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest,
                                 raw_request: Request):
    generator = await openai_serving_chat.create_chat_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
    generator = await openai_serving_completion.create_completion(
        request, raw_request)
    if isinstance(generator, ErrorResponse):
        return JSONResponse(content=generator.model_dump(),
                            status_code=generator.code)
    if request.stream:
        return StreamingResponse(content=generator,
                                 media_type="text/event-stream")
    else:
        return JSONResponse(content=generator.model_dump())


if __name__ == "__main__":
    args = parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )

    if token := os.environ.get("VLLM_API_KEY") or args.api_key:

        @app.middleware("http")
        async def authentication(request: Request, call_next):
            if not request.url.path.startswith("/v1"):
                return await call_next(request)
            if request.headers.get("Authorization") != "Bearer " + token:
                return JSONResponse(content={"error": "Unauthorized"},
                                    status_code=401)
            return await call_next(request)
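        # Illustrative sketch (not part of the original file): with an API key
        # configured, /v1 routes expect a matching bearer token, e.g.
        #
        #   curl http://localhost:8000/v1/models \
        #        -H "Authorization: Bearer $VLLM_API_KEY"
        #
        # Requests without (or with a wrong) token receive the 401 above.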

    for middleware in args.middleware:
        module_path, object_name = middleware.rsplit(".", 1)
        imported = getattr(importlib.import_module(module_path), object_name)
        if inspect.isclass(imported):
            app.add_middleware(imported)
        elif inspect.iscoroutinefunction(imported):
            app.middleware("http")(imported)
        else:
            raise ValueError(
                f"Invalid middleware {middleware}. Must be a function or a class."
            )
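    # Illustrative sketch (not part of the original file): a function passed
    # via --middleware must be an async coroutine with this shape; the module
    # path used on the command line (e.g. my_module.add_custom_header) is
    # hypothetical.
    #
    # >>> async def add_custom_header(request, call_next):
    # ...     response = await call_next(request)
    # ...     response.headers["X-Served-By"] = "vllm"
    # ...     return response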

    logger.info(f"args: {args}")

    if args.served_model_name is not None:
        served_model = args.served_model_name
    else:
        served_model = args.model

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(engine_args)
    openai_serving_chat = OpenAIServingChat(engine, served_model,
                                            args.response_role,
                                            args.lora_modules,
                                            args.chat_template)
    openai_serving_completion = OpenAIServingCompletion(
        engine, served_model, args.lora_modules)

    app.root_path = args.root_path
    uvicorn.run(app,
                host=args.host,
                port=args.port,
                log_level="info",
                timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
                ssl_keyfile=args.ssl_keyfile,
                ssl_certfile=args.ssl_certfile)
323
vllm/entrypoints/openai/protocol.py
Normal file
@@ -0,0 +1,323 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
import time
from typing import Dict, List, Literal, Optional, Union

from pydantic import BaseModel, Field, model_validator

from vllm.utils import random_uuid
from vllm.sampling_params import SamplingParams

import torch


class ErrorResponse(BaseModel):
    object: str = "error"
    message: str
    type: str
    param: Optional[str] = None
    code: int


class ModelPermission(BaseModel):
    id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
    object: str = "model_permission"
    created: int = Field(default_factory=lambda: int(time.time()))
    allow_create_engine: bool = False
    allow_sampling: bool = True
    allow_logprobs: bool = True
    allow_search_indices: bool = False
    allow_view: bool = True
    allow_fine_tuning: bool = False
    organization: str = "*"
    group: Optional[str] = None
    is_blocking: bool = False


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "vllm"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: List[ModelPermission] = Field(default_factory=list)


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = Field(default_factory=list)


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = 0.7
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    max_tokens: Optional[int] = None
    seed: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    stream: Optional[bool] = False
    logprobs: Optional[bool] = False
    top_logprobs: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    best_of: Optional[int] = None
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    add_generation_prompt: Optional[bool] = True
    echo: Optional[bool] = False
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self) -> SamplingParams:
        if self.logprobs and not self.top_logprobs:
            raise ValueError(
                "`top_logprobs` must be set when `logprobs` is enabled.")

        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            max_tokens=self.max_tokens,
            logprobs=self.top_logprobs if self.logprobs else None,
            prompt_logprobs=self.top_logprobs if self.echo else None,
            best_of=self.best_of,
            top_k=self.top_k,
            ignore_eos=self.ignore_eos,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data
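# Illustrative sketch (not part of the original file): converting an incoming
# request into engine-level sampling parameters. Field values are placeholders.
#
# >>> req = ChatCompletionRequest(
# ...     model="demo",
# ...     messages=[{"role": "user", "content": "Hi"}],
# ...     temperature=0.0, max_tokens=32)
# >>> params = req.to_sampling_params()
# >>> params.temperature, params.max_tokens
# (0.0, 32)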


class CompletionRequest(BaseModel):
    model: str
    # a string, array of strings, array of tokens, or array of token arrays
    prompt: Union[List[int], List[List[int]], str, List[str]]
    suffix: Optional[str] = None
    max_tokens: Optional[int] = 16
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    n: Optional[int] = 1
    stream: Optional[bool] = False
    logprobs: Optional[int] = None
    echo: Optional[bool] = False
    stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
    seed: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    best_of: Optional[int] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    # Additional parameters supported by vLLM
    top_k: Optional[int] = -1
    ignore_eos: Optional[bool] = False
    use_beam_search: Optional[bool] = False
    early_stopping: Optional[bool] = False
    stop_token_ids: Optional[List[int]] = Field(default_factory=list)
    skip_special_tokens: Optional[bool] = True
    spaces_between_special_tokens: Optional[bool] = True
    repetition_penalty: Optional[float] = 1.0
    min_p: Optional[float] = 0.0
    include_stop_str_in_output: Optional[bool] = False
    length_penalty: Optional[float] = 1.0
    guided_json: Optional[Union[str, dict, BaseModel]] = None
    guided_regex: Optional[str] = None
    guided_choice: Optional[List[str]] = None

    def to_sampling_params(self):
        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = None
        if self.logit_bias:

            def logit_bias_logits_processor(
                    token_ids: List[int],
                    logits: torch.Tensor) -> torch.Tensor:
                for token_id, bias in self.logit_bias.items():
                    # Clamp the bias between -100 and 100 per OpenAI API spec
                    bias = min(100, max(-100, bias))
                    logits[int(token_id)] += bias
                return logits

            logits_processors = [logit_bias_logits_processor]

        return SamplingParams(
            n=self.n,
            best_of=self.best_of,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            repetition_penalty=self.repetition_penalty,
            temperature=self.temperature,
            top_p=self.top_p,
            top_k=self.top_k,
            min_p=self.min_p,
            seed=self.seed,
            stop=self.stop,
            stop_token_ids=self.stop_token_ids,
            ignore_eos=self.ignore_eos,
            max_tokens=self.max_tokens if not echo_without_generation else 1,
            logprobs=self.logprobs,
            use_beam_search=self.use_beam_search,
            early_stopping=self.early_stopping,
            prompt_logprobs=self.logprobs if self.echo else None,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
            include_stop_str_in_output=self.include_stop_str_in_output,
            length_penalty=self.length_penalty,
            logits_processors=logits_processors,
        )

    @model_validator(mode="before")
    @classmethod
    def check_guided_decoding_count(cls, data):
        guide_count = sum([
            "guided_json" in data and data["guided_json"] is not None,
            "guided_regex" in data and data["guided_regex"] is not None,
            "guided_choice" in data and data["guided_choice"] is not None
        ])
        if guide_count > 1:
            raise ValueError(
                "You can only use one kind of guided decoding "
                "('guided_json', 'guided_regex' or 'guided_choice').")
        return data


class LogProbs(BaseModel):
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[Optional[float]] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None


class CompletionResponseChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseChoice]
    usage: UsageInfo


class CompletionResponseStreamChoice(BaseModel):
    index: int
    text: str
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class CompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
    object: str = "text_completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[CompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)


class ChatMessage(BaseModel):
    role: str
    content: str


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseChoice]
    usage: UsageInfo


class DeltaMessage(BaseModel):
    role: Optional[str] = None
    content: Optional[str] = None


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    logprobs: Optional[LogProbs] = None
    finish_reason: Optional[Literal["stop", "length"]] = None


class ChatCompletionStreamResponse(BaseModel):
    id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
    object: str = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: List[ChatCompletionResponseStreamChoice]
    usage: Optional[UsageInfo] = Field(default=None)
307
vllm/entrypoints/openai/serving_chat.py
Normal file
@@ -0,0 +1,307 @@
import time
import codecs
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Optional, List, Union
from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest, ChatCompletionResponse,
    ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
    UsageInfo)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor

logger = init_logger(__name__)


class OpenAIServingChat(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 served_model: str,
                 response_role: str,
                 lora_modules: Optional[List[LoRA]] = None,
                 chat_template=None):
        super().__init__(engine=engine,
                         served_model=served_model,
                         lora_modules=lora_modules)
        self.response_role = response_role
        self._load_chat_template(chat_template)

    async def create_chat_completion(
        self, request: ChatCompletionRequest, raw_request: Request
    ) -> Union[ErrorResponse, AsyncGenerator[str, None],
               ChatCompletionResponse]:
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/chat/create
        for the API specification. This API mimics the OpenAI ChatCompletion API.

        NOTE: Currently we do not support the following feature:
            - function_call (Users should implement this by themselves)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        try:
            prompt = self.tokenizer.apply_chat_template(
                conversation=request.messages,
                tokenize=False,
                add_generation_prompt=request.add_generation_prompt)
        except Exception as e:
            logger.error(
                f"Error in applying chat template from request: {str(e)}")
            return self.create_error_response(str(e))

        request_id = f"cmpl-{random_uuid()}"
        try:
            token_ids = self._validate_prompt_and_tokenize(request,
                                                           prompt=prompt)
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            guided_decode_logits_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logits_processor:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logits_processor)
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator = self.engine.generate(prompt, sampling_params,
                                                request_id, token_ids,
                                                lora_request)
        # Streaming response
        if request.stream:
            return self.chat_completion_stream_generator(
                request, result_generator, request_id)
        else:
            return await self.chat_completion_full_generator(
                request, raw_request, result_generator, request_id)

    def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
        if request.add_generation_prompt:
            return self.response_role
        else:
            return request.messages[-1]["role"]

async def chat_completion_stream_generator(
|
||||
self, request: ChatCompletionRequest,
|
||||
result_generator: AsyncIterator[RequestOutput], request_id: str
|
||||
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
|
||||
|
||||
model_name = request.model
|
||||
created_time = int(time.monotonic())
|
||||
chunk_object_type = "chat.completion.chunk"
|
||||
|
||||
# Send first response for each request.n (index) with the role
|
||||
role = self.get_chat_request_role(request)
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(role=role),
|
||||
logprobs=None,
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response to echo the input portion of the last message
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
if last_msg_content:
|
||||
for i in range(request.n):
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=last_msg_content),
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
logprobs=None,
|
||||
model=model_name)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
|
||||
# Send response for each token for each request.n (index)
|
||||
previous_texts = [""] * request.n
|
||||
previous_num_tokens = [0] * request.n
|
||||
finish_reason_sent = [False] * request.n
|
||||
async for res in result_generator:
|
||||
res: RequestOutput
|
||||
for output in res.outputs:
|
||||
i = output.index
|
||||
|
||||
if finish_reason_sent[i]:
|
||||
continue
|
||||
|
||||
delta_token_ids = output.token_ids[previous_num_tokens[i]:]
|
||||
top_logprobs = output.logprobs[
|
||||
previous_num_tokens[i]:] if output.logprobs else None
|
||||
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=delta_token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
initial_text_offset=len(previous_texts[i]),
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
if output.finish_reason is None:
|
||||
# Send token-by-token response for each request.n
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=None)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
data = chunk.model_dump_json(exclude_unset=True)
|
||||
yield f"data: {data}\n\n"
|
||||
else:
|
||||
# Send the finish response for each request.n only once
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=request_id,
|
||||
object=chunk_object_type,
|
||||
created=created_time,
|
||||
choices=[choice_data],
|
||||
model=model_name)
|
||||
if final_usage is not None:
|
||||
chunk.usage = final_usage
|
||||
data = chunk.model_dump_json(exclude_unset=True,
|
||||
exclude_none=True)
|
||||
yield f"data: {data}\n\n"
|
||||
finish_reason_sent[i] = True
|
||||
# Send the final done message after all response.n are finished
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
async def chat_completion_full_generator(
|
||||
self, request: ChatCompletionRequest, raw_request: Request,
|
||||
result_generator: AsyncIterator[RequestOutput],
|
||||
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
|
||||
|
||||
model_name = request.model
|
||||
created_time = int(time.monotonic())
|
||||
final_res: RequestOutput = None
|
||||
|
||||
async for res in result_generator:
|
||||
if await raw_request.is_disconnected():
|
||||
# Abort the request if the client disconnects.
|
||||
await self.engine.abort(request_id)
|
||||
return self.create_error_response("Client disconnected")
|
||||
final_res = res
|
||||
assert final_res is not None
|
||||
|
||||
choices = []
|
||||
|
||||
role = self.get_chat_request_role(request)
|
||||
for output in final_res.outputs:
|
||||
token_ids = output.token_ids
|
||||
top_logprobs = output.logprobs
|
||||
|
||||
if request.logprobs:
|
||||
logprobs = self._create_logprobs(
|
||||
token_ids=token_ids,
|
||||
top_logprobs=top_logprobs,
|
||||
num_output_top_logprobs=request.logprobs,
|
||||
)
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=output.index,
|
||||
message=ChatMessage(role=role, content=output.text),
|
||||
logprobs=logprobs,
|
||||
finish_reason=output.finish_reason,
|
||||
)
|
||||
choices.append(choice_data)
|
||||
|
||||
if request.echo:
|
||||
last_msg_content = ""
|
||||
if request.messages and isinstance(
|
||||
request.messages, list) and request.messages[-1].get(
|
||||
"content") and request.messages[-1].get(
|
||||
"role") == role:
|
||||
last_msg_content = request.messages[-1]["content"]
|
||||
|
||||
for choice in choices:
|
||||
full_message = last_msg_content + choice.message.content
|
||||
choice.message.content = full_message
|
||||
|
||||
num_prompt_tokens = len(final_res.prompt_token_ids)
|
||||
num_generated_tokens = sum(
|
||||
len(output.token_ids) for output in final_res.outputs)
|
||||
usage = UsageInfo(
|
||||
prompt_tokens=num_prompt_tokens,
|
||||
completion_tokens=num_generated_tokens,
|
||||
total_tokens=num_prompt_tokens + num_generated_tokens,
|
||||
)
|
||||
response = ChatCompletionResponse(
|
||||
id=request_id,
|
||||
created=created_time,
|
||||
model=model_name,
|
||||
choices=choices,
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _load_chat_template(self, chat_template):
|
||||
if chat_template is not None:
|
||||
try:
|
||||
with open(chat_template, "r") as f:
|
||||
self.tokenizer.chat_template = f.read()
|
||||
except OSError:
|
||||
# If opening a file fails, set chat template to be args to
|
||||
# ensure we decode so our escape are interpreted correctly
|
||||
self.tokenizer.chat_template = codecs.decode(
|
||||
chat_template, "unicode_escape")
|
||||
|
||||
logger.info(
|
||||
f"Using supplied chat template:\n{self.tokenizer.chat_template}"
|
||||
)
|
||||
elif self.tokenizer.chat_template is not None:
|
||||
logger.info(
|
||||
f"Using default chat template:\n{self.tokenizer.chat_template}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No chat template provided. Chat API will not work.")
|
||||
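
The streaming handler above emits Server-Sent Events and terminates the stream with a `data: [DONE]` sentinel. A minimal client sketch for consuming that stream follows; it assumes the standard `/v1/chat/completions` route, a server on localhost:8000, a placeholder model name, and the third-party `requests` library, none of which are fixed by this file.

import json
import requests

# Stream a chat completion and print the delta content as it arrives.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "my-model",  # placeholder
        "messages": [{"role": "user", "content": "Hello"}],
        "stream": True,
    },
    stream=True,
)
for line in resp.iter_lines():
    if not line:
        continue  # blank separator line between SSE frames
    payload = line.decode("utf-8").removeprefix("data: ")
    if payload == "[DONE]":
        break  # sentinel emitted after all choices finish
    chunk = json.loads(payload)
    delta = chunk["choices"][0]["delta"]
    print(delta.get("content", ""), end="", flush=True)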
361
vllm/entrypoints/openai/serving_completion.py
Normal file
@@ -0,0 +1,361 @@
import asyncio
import time
from fastapi import Request
from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple

from vllm.logger import init_logger
from vllm.utils import random_uuid
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    LogProbs,
    UsageInfo,
)
from vllm.outputs import RequestOutput
from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA
from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor

logger = init_logger(__name__)

TypeTokenIDs = List[int]
TypeTopLogProbs = List[Optional[Dict[int, float]]]
TypeCreateLogProbsFn = Callable[
    [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs]


async def completion_stream_generator(
    request: CompletionRequest,
    raw_request: Request,
    on_abort,
    result_generator: AsyncIterator[Tuple[int, RequestOutput]],
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
    created_time: int,
    model_name: str,
    num_prompts: int,
) -> AsyncGenerator[str, None]:
    previous_texts = [""] * request.n * num_prompts
    previous_num_tokens = [0] * request.n * num_prompts
    has_echoed = [False] * request.n * num_prompts

    async for prompt_idx, res in result_generator:

        # Abort the request if the client disconnects.
        if await raw_request.is_disconnected():
            await on_abort(f"{request_id}-{prompt_idx}")
            # Raising StopAsyncIteration inside an async generator is turned
            # into a RuntimeError (PEP 525), so simply return to end the
            # stream.
            return

        for output in res.outputs:
            i = output.index + prompt_idx * request.n
            # TODO(simon): optimize the performance by avoiding full text O(n^2) sending.

            if request.echo and request.max_tokens == 0:
                # only return the prompt
                delta_text = res.prompt
                delta_token_ids = res.prompt_token_ids
                top_logprobs = res.prompt_logprobs
                has_echoed[i] = True
            elif request.echo and request.max_tokens > 0 and not has_echoed[i]:
                # echo the prompt and first token
                delta_text = res.prompt + output.text
                delta_token_ids = res.prompt_token_ids + output.token_ids
                top_logprobs = res.prompt_logprobs + (output.logprobs or [])
                has_echoed[i] = True
            else:
                # return just the delta
                delta_text = output.text[len(previous_texts[i]):]
                delta_token_ids = output.token_ids[previous_num_tokens[i]:]
                top_logprobs = output.logprobs[
                    previous_num_tokens[i]:] if output.logprobs else None

            if request.logprobs is not None:
                assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested"
                logprobs = create_logprobs_fn(
                    token_ids=delta_token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                    initial_text_offset=len(previous_texts[i]),
                )
            else:
                logprobs = None

            previous_texts[i] = output.text
            previous_num_tokens[i] = len(output.token_ids)
            finish_reason = output.finish_reason
            response_json = CompletionStreamResponse(
                id=request_id,
                created=created_time,
                model=model_name,
                choices=[
                    CompletionResponseStreamChoice(
                        index=i,
                        text=delta_text,
                        logprobs=logprobs,
                        finish_reason=finish_reason,
                    )
                ]).model_dump_json()
            yield f"data: {response_json}\n\n"

            if output.finish_reason is not None:  # return final usage
                logprobs = LogProbs() if request.logprobs is not None else None
                prompt_tokens = len(res.prompt_token_ids)
                completion_tokens = len(output.token_ids)
                final_usage = UsageInfo(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=prompt_tokens + completion_tokens,
                )
                response_json = CompletionStreamResponse(
                    id=request_id,
                    created=created_time,
                    model=model_name,
                    choices=[
                        CompletionResponseStreamChoice(
                            index=i,
                            text="",
                            logprobs=logprobs,
                            finish_reason=output.finish_reason,
                        )
                    ],
                    usage=final_usage,
                ).model_dump_json()
                yield f"data: {response_json}\n\n"

    yield "data: [DONE]\n\n"


def parse_prompt_format(prompt) -> Tuple[bool, list]:
    # Get the prompt. OpenAI supports the following:
    # "a string, array of strings, array of tokens, or array of token arrays."
    prompt_is_tokens = False
    prompts = [prompt]  # case 1: a string
    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")
        elif isinstance(prompt[0], str):
            prompt_is_tokens = False
            prompts = prompt  # case 2: array of strings
        elif isinstance(prompt[0], int):
            prompt_is_tokens = True
            prompts = [prompt]  # case 3: array of tokens
        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
            prompt_is_tokens = True
            prompts = prompt  # case 4: array of token arrays
        else:
            raise ValueError(
                "prompt must be a string, array of strings, array of tokens, or array of token arrays"
            )
    return prompt_is_tokens, prompts


def request_output_to_completion_response(
    final_res_batch: List[RequestOutput],
    request: CompletionRequest,
    create_logprobs_fn: TypeCreateLogProbsFn,
    request_id: str,
    created_time: int,
    model_name: str,
) -> CompletionResponse:
    choices = []
    num_prompt_tokens = 0
    num_generated_tokens = 0
    for final_res in final_res_batch:
        assert final_res is not None
        prompt_token_ids = final_res.prompt_token_ids
        prompt_logprobs = final_res.prompt_logprobs
        prompt_text = final_res.prompt

        for output in final_res.outputs:
            if request.echo and request.max_tokens == 0:
                token_ids = prompt_token_ids
                top_logprobs = prompt_logprobs
                output_text = prompt_text
            elif request.echo and request.max_tokens > 0:
                token_ids = prompt_token_ids + output.token_ids
                top_logprobs = prompt_logprobs + output.logprobs
                output_text = prompt_text + output.text
            else:
                token_ids = output.token_ids
                top_logprobs = output.logprobs
                output_text = output.text

            if request.logprobs is not None:
                logprobs = create_logprobs_fn(
                    token_ids=token_ids,
                    top_logprobs=top_logprobs,
                    num_output_top_logprobs=request.logprobs,
                )
            else:
                logprobs = None

            choice_data = CompletionResponseChoice(
                index=len(choices),
                text=output_text,
                logprobs=logprobs,
                finish_reason=output.finish_reason,
            )
            choices.append(choice_data)

        num_prompt_tokens += len(prompt_token_ids)
        num_generated_tokens += sum(
            len(output.token_ids) for output in final_res.outputs)

    usage = UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
    )

    return CompletionResponse(
        id=request_id,
        created=created_time,
        model=model_name,
        choices=choices,
        usage=usage,
    )


def merge_async_iterators(*iterators):
    """Merge multiple asynchronous iterators into a single iterator.

    This method handles the case where some iterators finish before others.
    When it yields, it yields a tuple (i, item) where i is the index of the
    iterator that yields the item.
    """
    queue = asyncio.Queue()

    finished = [False] * len(iterators)

    async def producer(i, iterator):
        async for item in iterator:
            await queue.put((i, item))
        finished[i] = True

    _tasks = [
        asyncio.create_task(producer(i, iterator))
        for i, iterator in enumerate(iterators)
    ]

    async def consumer():
        while not all(finished) or not queue.empty():
            item = await queue.get()
            yield item
        await asyncio.gather(*_tasks)

    return consumer()


class OpenAIServingCompletion(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 served_model: str,
                 lora_modules: Optional[List[LoRA]] = None):
        super().__init__(engine=engine,
                         served_model=served_model,
                         lora_modules=lora_modules)

    async def create_completion(self, request: CompletionRequest,
                                raw_request: Request):
        """Completion API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/completions/create
        for the API specification. This API mimics the OpenAI Completion API.

        NOTE: Currently we do not support the following feature:
            - suffix (the language models we currently support do not support
              suffix)
        """
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        # Return error for unsupported features.
        if request.suffix is not None:
            return self.create_error_response(
                "suffix is not currently supported")

        model_name = request.model
        request_id = f"cmpl-{random_uuid()}"
        # Unix timestamp, per the OpenAI API.
        created_time = int(time.time())

        # Schedule the request and get the result generator.
        generators = []
        try:
            sampling_params = request.to_sampling_params()
            lora_request = self._maybe_get_lora(request)
            guided_decode_logit_processor = (
                await get_guided_decoding_logits_processor(
                    request, self.engine.get_tokenizer()))
            if guided_decode_logit_processor is not None:
                if sampling_params.logits_processors is None:
                    sampling_params.logits_processors = []
                sampling_params.logits_processors.append(
                    guided_decode_logit_processor)
            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)

            for i, prompt in enumerate(prompts):
                if prompt_is_tokens:
                    input_ids = self._validate_prompt_and_tokenize(
                        request, prompt_ids=prompt)
                else:
                    input_ids = self._validate_prompt_and_tokenize(
                        request, prompt=prompt)

                generators.append(
                    self.engine.generate(prompt,
                                         sampling_params,
                                         f"{request_id}-{i}",
                                         prompt_token_ids=input_ids,
                                         lora_request=lora_request))
        except ValueError as e:
            return self.create_error_response(str(e))

        result_generator: AsyncIterator[Tuple[
            int, RequestOutput]] = merge_async_iterators(*generators)

        # Similar to the OpenAI API, when n != best_of, we do not stream the
        # results. In addition, we do not stream the results when beam search
        # is used.
        stream = (request.stream
                  and (request.best_of is None or request.n == request.best_of)
                  and not request.use_beam_search)

        # Streaming response
        if stream:
            return completion_stream_generator(request,
                                               raw_request,
                                               self.engine.abort,
                                               result_generator,
                                               self._create_logprobs,
                                               request_id,
                                               created_time,
                                               model_name,
                                               num_prompts=len(prompts))

        # Non-streaming response
        final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts)
        async for i, res in result_generator:
            if await raw_request.is_disconnected():
                # Abort the request if the client disconnects.
                await self.engine.abort(f"{request_id}-{i}")
                return self.create_error_response("Client disconnected")
            final_res_batch[i] = res
        response = request_output_to_completion_response(
            final_res_batch, request, self._create_logprobs, request_id,
            created_time, model_name)

        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
        if request.stream:
            response_json = response.model_dump_json()

            async def fake_stream_generator() -> AsyncGenerator[str, None]:
                yield f"data: {response_json}\n\n"
                yield "data: [DONE]\n\n"

            return fake_stream_generator()

        return response
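
`merge_async_iterators` above is a self-contained utility, so its semantics are easy to demonstrate in isolation: items from all input iterators are interleaved as they become available, each tagged with the index of its source. A runnable sketch follows; the tickers and delays are illustrative only.

import asyncio

from vllm.entrypoints.openai.serving_completion import merge_async_iterators

async def ticker(name: str, delay: float, count: int):
    # A toy async iterator that yields `count` labeled items.
    for n in range(count):
        await asyncio.sleep(delay)
        yield f"{name}-{n}"

async def main():
    merged = merge_async_iterators(ticker("fast", 0.01, 3),
                                   ticker("slow", 0.03, 2))
    async for src_idx, item in merged:
        # src_idx identifies which input iterator produced the item.
        print(src_idx, item)

asyncio.run(main())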
172
vllm/entrypoints/openai/serving_engine.py
Normal file
@@ -0,0 +1,172 @@
import asyncio
from dataclasses import dataclass
from http import HTTPStatus
from typing import Dict, List, Optional, Union

from vllm.logger import init_logger
from vllm.transformers_utils.tokenizer import get_tokenizer
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (CompletionRequest,
                                              ChatCompletionRequest,
                                              ErrorResponse, LogProbs,
                                              ModelCard, ModelList,
                                              ModelPermission)
from vllm.lora.request import LoRARequest

logger = init_logger(__name__)


@dataclass
class LoRA:
    name: str
    local_path: str


class OpenAIServing:

    def __init__(self,
                 engine: AsyncLLMEngine,
                 served_model: str,
                 lora_modules: Optional[List[LoRA]] = None):
        self.engine = engine
        self.served_model = served_model
        if lora_modules is None:
            self.lora_requests = []
        else:
            self.lora_requests = [
                LoRARequest(
                    lora_name=lora.name,
                    lora_int_id=i,
                    lora_local_path=lora.local_path,
                ) for i, lora in enumerate(lora_modules, start=1)
            ]

        self.max_model_len = 0
        self.tokenizer = None

        try:
            event_loop = asyncio.get_running_loop()
        except RuntimeError:
            event_loop = None

        if event_loop is not None and event_loop.is_running(
        ):  # If instantiated by Ray Serve, there is already a running event loop
            event_loop.create_task(self._post_init())
        else:  # When using a single vLLM instance without engine_use_ray
            asyncio.run(self._post_init())

    async def _post_init(self):
        engine_model_config = await self.engine.get_model_config()
        self.max_model_len = engine_model_config.max_model_len

        # A separate tokenizer to map token IDs to strings.
        self.tokenizer = get_tokenizer(
            engine_model_config.tokenizer,
            tokenizer_mode=engine_model_config.tokenizer_mode,
            trust_remote_code=engine_model_config.trust_remote_code)

    async def show_available_models(self) -> ModelList:
        """Show available models. Right now we only have one model."""
        model_cards = [
            ModelCard(id=self.served_model,
                      root=self.served_model,
                      permission=[ModelPermission()])
        ]
        lora_cards = [
            ModelCard(id=lora.lora_name,
                      root=self.served_model,
                      permission=[ModelPermission()])
            for lora in self.lora_requests
        ]
        model_cards.extend(lora_cards)
        return ModelList(data=model_cards)

    def _create_logprobs(
        self,
        token_ids: List[int],
        top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
        num_output_top_logprobs: Optional[int] = None,
        initial_text_offset: int = 0,
    ) -> LogProbs:
        """Create OpenAI-style logprobs."""
        logprobs = LogProbs()
        last_token_len = 0
        if num_output_top_logprobs:
            logprobs.top_logprobs = []
        for i, token_id in enumerate(token_ids):
            step_top_logprobs = top_logprobs[i]
            if step_top_logprobs is not None:
                token_logprob = step_top_logprobs[token_id]
            else:
                token_logprob = None
            token = self.tokenizer.convert_ids_to_tokens(token_id)
            logprobs.tokens.append(token)
            logprobs.token_logprobs.append(token_logprob)
            if len(logprobs.text_offset) == 0:
                logprobs.text_offset.append(initial_text_offset)
            else:
                logprobs.text_offset.append(logprobs.text_offset[-1] +
                                            last_token_len)
            last_token_len = len(token)

            if num_output_top_logprobs:
                logprobs.top_logprobs.append({
                    self.tokenizer.convert_ids_to_tokens(i): p
                    for i, p in step_top_logprobs.items()
                } if step_top_logprobs else None)
        return logprobs

    def create_error_response(
            self,
            message: str,
            err_type: str = "BadRequestError",
            status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
        return ErrorResponse(message=message,
                             type=err_type,
                             code=status_code.value)

    async def _check_model(self, request) -> Optional[ErrorResponse]:
        if request.model == self.served_model:
            return
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return
        return self.create_error_response(
            message=f"The model `{request.model}` does not exist.",
            err_type="NotFoundError",
            status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
        if request.model == self.served_model:
            return
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None) -> List[int]:
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if (prompt and prompt_ids):
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
            prompt).input_ids
        token_num = len(input_ids)

        if request.max_tokens is None:
            request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
            raise ValueError(
                f"This model's maximum context length is {self.max_model_len} tokens. "
                f"However, you requested {request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.", )
        else:
            return input_ids
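
The context-length check in `_validate_prompt_and_tokenize` above is simple window arithmetic; a worked example with illustrative numbers:

# max_model_len = 4096 and the prompt tokenizes to 1000 tokens.
max_model_len = 4096
token_num = 1000
max_tokens = None  # client left max_tokens unset

if max_tokens is None:
    # Default: fill the remaining context window.
    max_tokens = max_model_len - token_num  # -> 3096

assert token_num + max_tokens <= max_model_len  # 1000 + 3096 == 4096: accepted
# Had the client asked for max_tokens=3500, then 1000 + 3500 > 4096 and the
# request would be rejected with the error message above.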
61
vllm/logger.py
Normal file
@@ -0,0 +1,61 @@
# Adapted from
# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
"""Logging configuration for vLLM."""
import logging
import sys
import os

VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1"))

_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"


class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None):
        logging.Formatter.__init__(self, fmt, datefmt)

    def format(self, record):
        msg = logging.Formatter.format(self, record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg


_root_logger = logging.getLogger("vllm")
_default_handler = None


def _setup_logger():
    _root_logger.setLevel(logging.DEBUG)
    global _default_handler
    if _default_handler is None:
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.flush = sys.stdout.flush  # type: ignore
        _default_handler.setLevel(logging.INFO)
        _root_logger.addHandler(_default_handler)
    fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
    _default_handler.setFormatter(fmt)
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
if VLLM_CONFIGURE_LOGGING:
    _setup_logger()


def init_logger(name: str):
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG"))
    if VLLM_CONFIGURE_LOGGING:
        logger.addHandler(_default_handler)
        logger.propagate = False
    return logger
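
A short usage sketch for the logging module above; the message text is illustrative. `NewLineFormatter` repeats the level/date/location prefix on every continuation line, so multi-line messages stay visually aligned:

from vllm.logger import init_logger

logger = init_logger(__name__)
logger.info("first line\nsecond line")
# Both lines are printed with the same "INFO mm-dd HH:MM:SS file.py:NN]"
# prefix in front of them.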
0
vllm/lora/__init__.py
Normal file
BIN
vllm/lora/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/layers.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/lora.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/models.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/punica.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/request.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/lora/__pycache__/worker_manager.cpython-310.pyc
Normal file
Binary file not shown.
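
The next file wraps model layers with multi-LoRA support, delegating the hot loop to the Punica kernels (`add_lora`, `add_lora_slice`, `bgmv`). As a reading aid, here is a minimal unfused PyTorch sketch of the per-row computation that `_apply_lora` in the file below performs; shapes follow its docstring, and `apply_lora_reference` is a hypothetical name, not part of this commit:

import torch

def apply_lora_reference(x: torch.Tensor, lora_a_stacked: torch.Tensor,
                         lora_b_stacked: torch.Tensor,
                         indices: torch.Tensor,
                         output: torch.Tensor) -> torch.Tensor:
    # x: (batch, hidden); lora_a_stacked: (num_loras, rank, hidden);
    # lora_b_stacked: (num_loras, out_dim, rank); indices: (batch,).
    for row in range(x.shape[0]):
        idx = int(indices[row])
        if idx < 0:
            continue  # index -1 means: no LoRA for this token
        a = lora_a_stacked[idx]  # (rank, hidden)
        b = lora_b_stacked[idx]  # (out_dim, rank)
        # Add the low-rank update on top of the base layer's output.
        output[row] += b @ (a @ x[row])
    return output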
979
vllm/lora/layers.py
Normal file
@@ -0,0 +1,979 @@
|
||||
# pylint: disable=unused-argument
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_gather,
|
||||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_gather,
|
||||
)
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear,
|
||||
QKVParallelLinear,
|
||||
MergedColumnParallelLinear)
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
def _apply_lora(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, output_dim)
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
|
||||
return output.view_as(org_output)
|
||||
|
||||
|
||||
def _apply_lora_packed_nslice(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
output_slices: Tuple[int, ...],
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
This method is used for layers that are composed of multiple sublayers
|
||||
(slices) packed together.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
||||
output_slices: n-1 element tuple of (slice_size...), where n is number of slices
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
offset_left = 0
|
||||
for slice_idx in range(len(output_slices)):
|
||||
add_lora_slice(output, x, lora_a_stacked[slice_idx],
|
||||
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
|
||||
output_slices[slice_idx])
|
||||
offset_left += output_slices[slice_idx]
|
||||
return output.view_as(org_output)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping:
|
||||
# Per every token in input_ids:
|
||||
index_mapping: Tuple[int, ...]
|
||||
# Per sampled token:
|
||||
prompt_mapping: Tuple[int, ...]
|
||||
|
||||
def __post_init__(self):
|
||||
self.index_mapping = tuple(self.index_mapping)
|
||||
self.prompt_mapping = tuple(self.prompt_mapping)
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
||||
def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig,
|
||||
model_config: PretrainedConfig) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
...
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
"""Resets the lora weights at index back to 0."""
|
||||
...
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
"""Overwrites lora tensors at index."""
|
||||
...
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
"""Sets the mapping indices."""
|
||||
...
|
||||
|
||||
|
||||
class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: VocabParallelEmbedding) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
|
||||
lora_vocab_start_idx = self.base_layer.org_vocab_size
|
||||
weights_idx = None
|
||||
if self.base_layer.vocab_end_index > lora_vocab_start_idx:
|
||||
# We can start adding lora weights
|
||||
weights_idx = max(
|
||||
lora_vocab_start_idx - self.base_layer.vocab_start_index, 0)
|
||||
self.embeddings_slice = (self.base_layer.vocab_start_index -
|
||||
self.base_layer.org_vocab_size +
|
||||
weights_idx,
|
||||
self.base_layer.vocab_end_index -
|
||||
self.base_layer.org_vocab_size)
|
||||
self.embeddings_weights = self.base_layer.weight.data[weights_idx:]
|
||||
self.embeddings_weights.fill_(0)
|
||||
else:
|
||||
self.embeddings_slice = None
|
||||
self.embeddings_weights = None
|
||||
|
||||
self.embeddings_tensors = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
self.base_layer.embedding_dim,
|
||||
),
|
||||
dtype=self.base_layer.weight.dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
self.base_layer.org_vocab_size +
|
||||
lora_config.lora_extra_vocab_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.embedding_dim,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_a_stacked_2d = self.lora_a_stacked.view(
|
||||
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
|
||||
self.lora_a_stacked.shape[2],
|
||||
)
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
self.embeddings_indices = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_(
|
||||
lora_a, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1]].copy_(embeddings_tensor, non_blocking=True)
|
||||
if self.embeddings_slice is not None:
|
||||
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||
# everything, just the modified part
|
||||
embeddings = self.embeddings_tensors.view(
|
||||
self.embeddings_tensors.shape[0] *
|
||||
self.embeddings_tensors.shape[1],
|
||||
self.embeddings_tensors.shape[2]
|
||||
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
|
||||
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.embeddings_indices = embeddings_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
|
||||
indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x)
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x)
|
||||
full_output = self.base_layer.forward(
|
||||
x.add_(indices * added_tokens_mask))
|
||||
|
||||
full_output_org = full_output
|
||||
if full_output.ndim == 3:
|
||||
full_output = full_output.view(
|
||||
full_output.shape[0] * full_output.shape[1], -1)
|
||||
if full_lora_a_embeddings.ndim == 3:
|
||||
full_lora_a_embeddings = full_lora_a_embeddings.view(
|
||||
full_lora_a_embeddings.shape[0] *
|
||||
full_lora_a_embeddings.shape[1], -1)
|
||||
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
return full_output.view_as(full_output_org)
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: ColumnParallelLinear) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
self.output_dim = self.lora_b_stacked.shape[1]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of ColumnParallelLinear
|
||||
|
||||
Args:
|
||||
input_: Tensor whose last dimension is `input_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
bias = (self.base_layer.bias
|
||||
if not self.base_layer.skip_bias_add else None)
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_, bias)
|
||||
if self.base_layer.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = (self.base_layer.bias
|
||||
if self.base_layer.skip_bias_add else None)
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def linear_weights(self):
|
||||
return self.base_layer.linear_weights
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
"""ColumnParallelLinear layer that is composed of 2 sublayers (slices)
|
||||
packed together (eg. gate_proj + up_proj -> gate_up_proj).
|
||||
|
||||
This means we have 2 LoRAs, each applied to one half of the layer.
|
||||
|
||||
Both slices must have the same size.
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: MergedColumnParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
n_slices = 2
|
||||
if not (len(self.base_layer.output_sizes) == n_slices
|
||||
and self.base_layer.output_sizes[0]
|
||||
== self.base_layer.output_sizes[1]):
|
||||
raise ValueError(
|
||||
"LoRAColumnParallelLinear2Slice requires 2 slices with "
|
||||
"the same size.")
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
|
||||
self.lora_a_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
) for _ in range(n_slices))
|
||||
self.lora_b_stacked = tuple(
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0] // 2,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
) for _ in range(n_slices))
|
||||
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.output_dim = self.lora_b_stacked[0].shape[2]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[0][index] = 0
|
||||
self.lora_a_stacked[1][index] = 0
|
||||
self.lora_b_stacked[0][index] = 0
|
||||
self.lora_b_stacked[1][index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
if self.tp_size > 1:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.output_dim
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_b = lora_b[0][:,
|
||||
start_idx:end_idx], lora_b[1][:,
|
||||
start_idx:end_idx]
|
||||
|
||||
if lora_a[0] is not None:
|
||||
self.lora_a_stacked[0][
|
||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
||||
lora_a[0].T, non_blocking=True)
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||
lora_b[0].T, non_blocking=True)
|
||||
if lora_a[1] is not None:
|
||||
self.lora_a_stacked[1][
|
||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||
lora_a[1].T, non_blocking=True)
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||
lora_b[1].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
(self.output_dim, self.output_dim),
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
"""ColumnParallelLinear layer that is composed of 3 sublayers (slices)
|
||||
packed together in qkv proj fashion
|
||||
(q_proj + k_proj + v_proj -> qkv_proj).
|
||||
|
||||
This means we have 3 LoRAs, each applied to one slice of the layer.
|
||||
|
||||
Q slice may have different shape than K and V slices (which both have
|
||||
the same shape).
|
||||
"""
|
||||
|
||||
def __init__(self, base_layer: QKVParallelLinear) -> None:
|
||||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
self.q_proj_shard_size = (self.base_layer.num_heads *
|
||||
self.base_layer.head_size)
|
||||
self.kv_proj_shard_size = (self.base_layer.num_kv_heads *
|
||||
self.base_layer.head_size)
|
||||
self.q_shard_id = tp_rank
|
||||
self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas
|
||||
|
||||
# q, k, v
|
||||
self.lora_a_stacked = (
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
)
|
||||
self.lora_b_stacked = (
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.q_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
torch.zeros(
|
||||
max_loras,
|
||||
1,
|
||||
self.kv_proj_shard_size,
|
||||
lora_config.max_lora_rank,
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
),
|
||||
)
|
||||
|
||||
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
|
||||
self.kv_proj_shard_size)
|
||||
self.packed_indices: Optional[torch.Tensor] = None
|
||||
self.standard_indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[0][index] = 0
|
||||
self.lora_b_stacked[0][index] = 0
|
||||
self.lora_a_stacked[1][index] = 0
|
||||
self.lora_b_stacked[1][index] = 0
|
||||
self.lora_a_stacked[2][index] = 0
|
||||
self.lora_b_stacked[2][index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
|
||||
if self.tp_size > 1:
|
||||
if lora_b[0] is not None:
|
||||
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
|
||||
self.q_shard_id:self.q_proj_shard_size *
|
||||
(self.q_shard_id + 1)]
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_(
|
||||
lora_b_q.T, non_blocking=True)
|
||||
if lora_b[1] is not None:
|
||||
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_(
|
||||
lora_b_k.T, non_blocking=True)
|
||||
if lora_b[2] is not None:
|
||||
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
self.lora_b_stacked[2][
|
||||
index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_(
|
||||
lora_b_v.T, non_blocking=True)
|
||||
else:
|
||||
if lora_b[0] is not None:
|
||||
self.lora_b_stacked[0][
|
||||
index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_(
|
||||
lora_b[0].T, non_blocking=True)
|
||||
if lora_b[1] is not None:
|
||||
self.lora_b_stacked[1][
|
||||
index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_(
|
||||
lora_b[1].T, non_blocking=True)
|
||||
if lora_b[2] is not None:
|
||||
self.lora_b_stacked[2][
|
||||
index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_(
|
||||
lora_b[2].T, non_blocking=True)
|
||||
|
||||
if lora_a[0] is not None:
|
||||
self.lora_a_stacked[0][
|
||||
index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_(
|
||||
lora_a[0].T, non_blocking=True)
|
||||
if lora_a[1] is not None:
|
||||
self.lora_a_stacked[1][
|
||||
index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_(
|
||||
lora_a[1].T, non_blocking=True)
|
||||
if lora_a[2] is not None:
|
||||
self.lora_a_stacked[2][
|
||||
index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_(
|
||||
lora_a[2].T, non_blocking=True)
|
||||
|
||||
def apply_weights(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
self.output_slices,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(self, base_layer: RowParallelLinear) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.base_layer.weight.shape[1],
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
self.base_layer.weight.shape[0],
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.base_layer.weight.device,
|
||||
)
|
||||
self.indices: Optional[torch.Tensor] = None
|
||||
self.indices_len: Optional[List[int]] = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
if self.base_layer.tp_size > 1:
|
||||
tensor_model_parallel_rank = get_tensor_model_parallel_rank()
|
||||
shard_size = self.base_layer.weight.shape[1]
|
||||
start_idx = tensor_model_parallel_rank * shard_size
|
||||
end_idx = (tensor_model_parallel_rank + 1) * shard_size
|
||||
lora_a = lora_a[start_idx:end_idx, :]
|
||||
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.linear_method.apply_weights(
|
||||
self.base_layer.linear_weights, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
"""Forward of RowParallelLinear
|
||||
|
||||
Args:
|
||||
input_: tensor whose last dimension is `input_size`. If
|
||||
`input_is_parallel` is set, then the last dimension
|
||||
is `input_size // tp_size`.
|
||||
|
||||
Returns:
|
||||
- output
|
||||
- bias
|
||||
"""
|
||||
# Set up backprop all-reduce.
|
||||
if self.base_layer.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
# TODO: simplify code below
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
splitted_input = split_tensor_along_last_dim(
|
||||
input_, num_partitions=self.base_layer.tp_size)
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.apply_weights(input_parallel)
|
||||
if self.base_layer.reduce_results and self.base_layer.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
output_ = output_parallel
|
||||
|
||||
if not self.base_layer.skip_bias_add:
|
||||
output = (output_ + self.base_layer.bias
|
||||
if self.base_layer.bias is not None else output_)
|
||||
output_bias = None
|
||||
else:
|
||||
output = output_
|
||||
output_bias = self.base_layer.bias
|
||||
return output, output_bias
|
||||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.base_layer.weight
|
||||
|
||||
|
||||
class SamplerWithLoRA(BaseLayerWithLoRA):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_layer: Sampler,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
self.hidden_size = hidden_size
|
||||
self.dtype = dtype
|
||||
self.device = device
|
||||
|
||||
@property
|
||||
def logits_as_hidden_states(self):
|
||||
return self.base_layer.logits_as_hidden_states
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
return self.base_layer.vocab_size
|
||||
|
||||
@property
|
||||
def org_vocab_size(self):
|
||||
return self.base_layer.org_vocab_size
|
||||
|
||||
@property
|
||||
def include_gpu_probs_tensor(self):
|
||||
return self.base_layer.include_gpu_probs_tensor
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
# Keep this in sync with csrc/punica/bgmv/bgmv_config.h
|
||||
if not (32000 <= self.base_layer.vocab_size <= 33024):
|
||||
raise ValueError(
|
||||
"When using LoRA, vocab size must be 32000 >= vocab_size <= 33024"
|
||||
)
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
lora_config.max_lora_rank,
|
||||
self.hidden_size,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.lora_b_stacked = torch.zeros(
|
||||
(
|
||||
max_loras,
|
||||
1,
|
||||
# Pad for kernel compatibility
|
||||
math.ceil(self.base_layer.vocab_size /
|
||||
lora_config.lora_vocab_padding_size) *
|
||||
lora_config.lora_vocab_padding_size,
|
||||
lora_config.max_lora_rank,
|
||||
),
|
||||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.embeddings_tensors = torch.full(
|
||||
(max_loras, lora_config.lora_extra_vocab_size, self.hidden_size),
|
||||
fill_value=float("-inf"),
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.indices = None
|
||||
self.indices_padded = None
|
||||
self.indices_len = None
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
self.embeddings_tensors[index] = float("-inf")
|
||||
|
||||
def set_lora(
|
||||
self,
|
||||
index: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor],
|
||||
):
|
||||
self.reset_lora(index)
|
||||
self.lora_a_stacked[index,
|
||||
0, :lora_a.shape[1], :lora_a.shape[0]].copy_(
|
||||
lora_a.T, non_blocking=True)
|
||||
self.lora_b_stacked[index,
|
||||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1], ] = embeddings_tensor
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = sampler_indices
|
||||
self.indices_padded = sampler_indices_padded
|
||||
self.indices_len = indices_len
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
embedding_bias: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
# Get the logits for the next tokens.
|
||||
logits = torch.matmul(hidden_states, embedding.t())
|
||||
if embedding_bias is not None:
|
||||
logits += embedding_bias
|
||||
logits = tensor_model_parallel_gather(logits)
|
||||
if logits is None:
|
||||
return None
|
||||
|
||||
lora_logits = torch.empty(
|
||||
self.embeddings_tensors.shape[0] + 1,
|
||||
self.embeddings_tensors.shape[1],
|
||||
hidden_states.shape[0],
|
||||
dtype=self.embeddings_tensors.dtype,
|
||||
device=self.embeddings_tensors.device,
|
||||
)
|
||||
torch.matmul(self.embeddings_tensors,
|
||||
hidden_states.T,
|
||||
out=lora_logits[:-1])
|
||||
lora_logits[-1] = float("-inf")
|
||||
lora_logits = lora_logits.mT
|
||||
lora_logits = (lora_logits.reshape(
|
||||
lora_logits.shape[0] * lora_logits.shape[1],
|
||||
lora_logits.shape[2],
|
||||
).index_select(0,
|
||||
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
|
||||
nan=float("-inf"),
|
||||
posinf=float("inf"),
|
||||
neginf=float("-inf")))
|
||||
logits[:,
|
||||
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
|
||||
lora_logits.shape[1]] = lora_logits
|
||||
|
||||
_apply_lora(
|
||||
hidden_states,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[1]],
|
||||
logits,
|
||||
)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.base_layer.vocab_size]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return type(self.base_layer).forward(self, *args, **kwargs)
|
||||
|
||||
|
||||
def from_layer(
|
||||
layer: nn.Module,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA:
|
||||
supported_layer_types = {
|
||||
VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
QKVParallelLinear: QKVParallelLinearWithLora,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck
|
||||
ret = lora_layer_type(layer)
|
||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
||||
return ret
|
||||
return layer
|
||||
|
||||
|
||||
def from_layer_sampler(
|
||||
layer: Sampler,
|
||||
lm_head: ParallelLMHead,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> SamplerWithLoRA:
|
||||
ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype,
|
||||
lm_head.weight.device)
|
||||
ret.create_lora_weights(max_loras, lora_config, model_config)
|
||||
return ret
|
||||
160
vllm/lora/lora.py
Normal file
@@ -0,0 +1,160 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from vllm.utils import in_wsl
|
||||
|
||||
|
||||
class LoRALayerWeights:
|
||||
"""LoRA weights for a layer composed of two low rank matrixes."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module_name: str,
|
||||
rank: int,
|
||||
lora_alpha: int,
|
||||
lora_a: torch.Tensor,
|
||||
lora_b: torch.Tensor,
|
||||
embeddings_tensor: Optional[torch.Tensor] = None,
|
||||
scaling: Optional[float] = None,
|
||||
) -> None:
|
||||
self.module_name = module_name
|
||||
self.rank = rank
|
||||
self.lora_alpha = lora_alpha
|
||||
self.lora_a = lora_a
|
||||
self.lora_b = lora_b
|
||||
self.embeddings_tensor = embeddings_tensor
|
||||
|
||||
if scaling is None:
|
||||
self.scaling = self.lora_alpha / self.rank
|
||||
else:
|
||||
self.scaling = scaling
|
||||
|
||||
def optimize(self) -> "LoRALayerWeights":
|
||||
"""Optimize the LoRA by merging the scaling into lora_b."""
|
||||
if self.scaling == 1:
|
||||
return self
|
||||
self.lora_b *= self.scaling
|
||||
self.scaling = 1
|
||||
return self
|
||||
|
||||
@property
|
||||
def input_dim(self) -> int:
|
||||
return self.lora_a.shape[0]
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
return self.lora_b.shape[1]
|
||||
|
||||
@property
|
||||
def is_packed(self) -> bool:
|
||||
return False
|
||||
|
||||
@property
|
||||
def extra_vocab_size(self) -> int:
|
||||
return self.embeddings_tensor.shape[
|
||||
0] if self.embeddings_tensor is not None else 0
|
||||
|
||||
@classmethod
|
||||
def create_dummy_lora_weights(
|
||||
cls,
|
||||
module_name: str,
|
||||
input_dim: int,
|
||||
output_dim: int,
|
||||
rank: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights":
|
||||
pin_memory = str(device) == "cpu" and not in_wsl()
|
||||
lora_a = torch.zeros([input_dim, rank],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory)
|
||||
lora_b = torch.zeros([rank, output_dim],
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory)
|
||||
embeddings_tensor = torch.rand(
|
||||
10,
|
||||
embeddings_tensor_dim,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
pin_memory=pin_memory) if embeddings_tensor_dim else None
|
||||
return cls(
|
||||
module_name,
|
||||
rank=rank,
|
||||
lora_alpha=1,
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
embeddings_tensor=embeddings_tensor,
|
||||
)
|
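# A minimal sketch (hypothetical values, not part of the module): scaling
# defaults to lora_alpha / rank, and optimize() folds it into lora_b so the
# scalar multiply disappears from later matmuls.
if __name__ == "__main__":
    _lora = LoRALayerWeights("proj", rank=8, lora_alpha=16,
                             lora_a=torch.ones(32, 8),
                             lora_b=torch.ones(8, 64))
    assert _lora.scaling == 2.0
    _lora.optimize()
    assert _lora.scaling == 1 and float(_lora.lora_b[0, 0]) == 2.0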
||||
|
||||
|
||||
class PackedLoRALayerWeights(LoRALayerWeights):
|
||||
"""LoRA used for packed layers (eg. qkv_proj)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module_name: str,
|
||||
rank: int,
|
||||
lora_alphas: List[int],
|
||||
lora_a: List[torch.Tensor],
|
||||
lora_b: List[torch.Tensor],
|
||||
scaling: Optional[List[float]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
module_name=module_name,
|
||||
rank=rank,
|
||||
lora_alpha=0,
|
||||
lora_a=lora_a,
|
||||
lora_b=lora_b,
|
||||
scaling=scaling,
|
||||
embeddings_tensor=None,
|
||||
)
|
||||
self.lora_alphas = lora_alphas
|
||||
if scaling is None:
|
||||
self.scaling = [
|
||||
lora_alpha / self.rank for lora_alpha in self.lora_alphas
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights":
|
||||
"""Pack a list of LoRAs into a single LoRA.
|
||||
|
||||
If LoRA is None, it signifies that the submodule does not have a LoRA.
|
||||
"""
|
||||
first_lora = next(lora for lora in loras if lora is not None)
|
||||
for lora in loras:
|
||||
if lora is None:
|
||||
continue
|
||||
lora.optimize()
|
||||
rank = first_lora.rank
|
||||
module_name = first_lora.module_name
|
||||
obj = cls(
|
||||
module_name,
|
||||
rank,
|
||||
[lora.lora_alpha if lora is not None else None for lora in loras],
|
||||
[lora.lora_a if lora is not None else None for lora in loras],
|
||||
[lora.lora_b if lora is not None else None for lora in loras],
|
||||
scaling=[1 if lora is not None else None for lora in loras])
|
||||
return obj
|
||||
|
||||
def optimize(self) -> "PackedLoRALayerWeights":
|
||||
"""Optimize the LoRA by merging the scaling into lora_b."""
|
||||
for i in range(len(self.lora_b)):
|
||||
if self.scaling[i] == 1 or self.lora_b[i] is None:
|
||||
continue
|
||||
self.lora_b[i] *= self.scaling[i]
|
||||
self.scaling[i] = 1
|
||||
return self
|
||||
|
||||
@property
|
||||
def input_dim(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def output_dim(self) -> int:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
def is_packed(self) -> bool:
|
||||
return True
|
||||
620
vllm/lora/models.py
Normal file
@@ -0,0 +1,620 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type)
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoRAConfig
|
||||
from vllm.utils import LRUCache, in_wsl
|
||||
|
||||
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GLOBAL_LORA_ID = 0
|
||||
|
||||
|
||||
def convert_mapping(
|
||||
mapping: LoRAMapping, lora_index_to_id: List[Optional[int]],
|
||||
max_loras: int, vocab_size: int, extra_vocab_size: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
|
||||
"""Converts LoRAMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
|
||||
lora_index_to_id: List mapping LoRA ids to LoRA indices.
|
||||
max_loras: Maximum number of LoRAs.
|
||||
vocab_size: Model vocab size.
|
||||
extra_vocab_size: Extra vocab size each LoRA can have.
|
||||
|
||||
Returns:
|
||||
A tuple of tensors:
|
||||
base_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
LoRA indices.
|
||||
sampler_indices: Tensor of shape [batch_size] mapping requests to
|
||||
LoRA indices for sampler. For generation, this will be the
|
||||
same as base_indices. For prefill, this will map requests
|
||||
to LoRA indices.
|
||||
sampler_indices_padded: Tensor of shape [batch_size] mapping
|
||||
requests to LoRA indices for sampler with padding.
|
||||
Same as sampler_indices, but -1 is replaced with
|
||||
max_loras.
|
||||
embeddings_indices: Tensor of shape [2, batch_size] mapping
|
||||
requests to embedding indices. First row is for embeddings
|
||||
added by the LoRAs, second row is for the LoRA.lora_a
|
||||
embeddings.
|
||||
indices_len: List of lengths of the above tensors.
|
||||
"""
|
||||
indices = list(mapping.index_mapping)
|
||||
embedding_indices = indices.copy()
|
||||
lora_indices = indices.copy()
|
||||
prompt_mapping = [
|
||||
lora_index_to_id.index(x) if x > 0 else -1
|
||||
for x in mapping.prompt_mapping
|
||||
]
|
||||
lora_idx = None
|
||||
for i in range(len(indices)):
|
||||
# TODO index can be slow. optimize
|
||||
lora_idx = (lora_index_to_id.index(indices[i])
|
||||
if indices[i] > 0 else -1)
|
||||
embedding_indices[i] = lora_idx if indices[i] > 0 else 0
|
||||
indices[i] = i
|
||||
lora_indices[i] = lora_idx
|
||||
|
||||
indices = torch.tensor([indices, lora_indices, embedding_indices],
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
prompt_mapping = torch.tensor(prompt_mapping,
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
embeddings_indices = torch.stack([
|
||||
indices[2] * extra_vocab_size,
|
||||
indices[2] * (vocab_size + extra_vocab_size)
|
||||
])
|
||||
embeddings_indices[embeddings_indices == -1] = max_loras - 1
|
||||
base_indices = indices[1]
|
||||
sampler_indices = prompt_mapping
|
||||
sampler_indices_padded = sampler_indices.clone()
|
||||
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
|
||||
sampler_indices_padded = (
|
||||
torch.arange(
|
||||
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
|
||||
(sampler_indices_padded * len(sampler_indices_padded)))
|
||||
indices_len = (base_indices.shape[-1], sampler_indices.shape[-1],
|
||||
sampler_indices_padded.shape[-1],
|
||||
embeddings_indices.shape[-1])
|
||||
|
||||
return (base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, indices_len)
|
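# Toy CPU-only sketch of the id -> slot lookup performed above (the real
# function additionally materializes CUDA index tensors). The slot
# assignments and ids below are hypothetical.
if __name__ == "__main__":
    _lora_index_to_id = [None, 7, 5, None]  # GPU slot -> LoRA int id
    _index_mapping = [7, 7, 5, 0]  # per-token LoRA ids; 0 means no LoRA
    _slots = [
        _lora_index_to_id.index(x) if x > 0 else -1 for x in _index_mapping
    ]
    assert _slots == [1, 1, 2, -1]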
||||
|
||||
|
||||
def get_lora_id():
|
||||
global _GLOBAL_LORA_ID
|
||||
_GLOBAL_LORA_ID += 1
|
||||
return _GLOBAL_LORA_ID
|
||||
|
||||
|
||||
class LoRAModel:
|
||||
"""A LoRA fine-tuned model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lora_model_id: int,
|
||||
rank: int,
|
||||
loras: Dict[str, LoRALayerWeights],
|
||||
) -> None:
|
||||
self.id = lora_model_id
|
||||
assert (lora_model_id >
|
||||
0), f"a valid lora id should be greater than 0, got {self.id}"
|
||||
self.rank = rank
|
||||
self.loras: Dict[str, LoRALayerWeights] = loras
|
||||
|
||||
@property
|
||||
def extra_vocab_size(self) -> int:
|
||||
return max(lora.extra_vocab_size
|
||||
for lora in self.loras.values()) if self.loras else 0
|
||||
|
||||
def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]:
|
||||
"""Get LoRA for a given module by name"""
|
||||
return self.loras.get(module_name, None)
|
||||
|
||||
# (yard1): TODO see if we can derive target_embedding_padding automatically
|
||||
@classmethod
|
||||
def from_lora_tensors(
|
||||
cls,
|
||||
lora_model_id: int,
|
||||
rank: int,
|
||||
lora_alpha: int,
|
||||
tensors: Dict[str, torch.Tensor],
|
||||
device: str = "cuda",
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
embeddings: Optional[Dict[str, torch.Tensor]] = None,
|
||||
target_embedding_padding: Optional[int] = None,
|
||||
embedding_modules: Optional[Dict[str, str]] = None,
|
||||
embedding_padding_modules: Optional[List[str]] = None,
|
||||
) -> "LoRAModel":
|
||||
"""Create a LoRAModel from a dictionary of tensors."""
|
||||
pin_memory = str(device) == "cpu" and not in_wsl()
|
||||
loras: Dict[str, LoRALayerWeights] = {}
|
||||
for tensor_name, tensor in tensors.items():
|
||||
module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name)
|
||||
if module_name not in loras:
|
||||
lora_embeddings_tensor = None
|
||||
if embeddings:
|
||||
embeddings_module = next(
|
||||
(k for k in embedding_modules if k in module_name),
|
||||
None)
|
||||
if embeddings_module:
|
||||
lora_embeddings_tensor = embeddings[
|
||||
embedding_modules[embeddings_module]].to(
|
||||
device=device, dtype=dtype)
|
||||
if pin_memory:
|
||||
lora_embeddings_tensor = (
|
||||
lora_embeddings_tensor.pin_memory())
|
||||
loras[module_name] = LoRALayerWeights(module_name, rank,
|
||||
lora_alpha, None, None,
|
||||
lora_embeddings_tensor)
|
||||
if is_lora_a:
|
||||
loras[module_name].lora_a = tensor.to(device=device,
|
||||
dtype=dtype).t()
|
||||
if pin_memory:
|
||||
loras[module_name].lora_a = loras[
|
||||
module_name].lora_a.pin_memory()
|
||||
else:
|
||||
loras[module_name].lora_b = tensor.to(device=device,
|
||||
dtype=dtype).t()
|
||||
if any(name in module_name
|
||||
for name in embedding_padding_modules
|
||||
) and target_embedding_padding is not None:
|
||||
lora_b = loras[module_name].lora_b
|
||||
assert target_embedding_padding >= lora_b.shape[1]
|
||||
addition = target_embedding_padding - lora_b.shape[1]
|
||||
loras[module_name].lora_b = torch.nn.functional.pad(
|
||||
lora_b, (0, addition))
|
||||
if pin_memory:
|
||||
loras[module_name].lora_b = loras[
|
||||
module_name].lora_b.pin_memory()
|
||||
|
||||
for lora in loras.values():
|
||||
lora.optimize()
|
||||
return cls(lora_model_id, rank, loras)
|
||||
|
||||
@classmethod
|
||||
def from_local_checkpoint(
|
||||
cls,
|
||||
lora_dir: str,
|
||||
lora_model_id: Optional[int] = None,
|
||||
device: str = "cuda",
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
target_embedding_padding: Optional[int] = None,
|
||||
embedding_modules: Optional[Dict[str, str]] = None,
|
||||
embedding_padding_modules: Optional[List[str]] = None,
|
||||
) -> "LoRAModel":
|
||||
"""Create a LoRAModel from a local checkpoint."""
|
||||
lora_config_path = os.path.join(lora_dir, "adapter_config.json")
|
||||
lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors")
|
||||
lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin")
|
||||
new_embeddings_tensor_path = os.path.join(
|
||||
lora_dir, "new_embeddings.safetensors")
|
||||
new_embeddings_bin_file_path = os.path.join(lora_dir,
|
||||
"new_embeddings.bin")
|
||||
if os.path.isfile(lora_tensor_path):
|
||||
tensors = safetensors.torch.load_file(lora_tensor_path)
|
||||
elif os.path.isfile(lora_bin_file_path):
|
||||
tensors = torch.load(lora_bin_file_path)
|
||||
else:
|
||||
raise ValueError(f"{lora_dir} doesn't contain tensors")
|
||||
|
||||
embeddings = None
|
||||
if os.path.isfile(new_embeddings_tensor_path):
|
||||
embeddings = safetensors.torch.load_file(
|
||||
new_embeddings_tensor_path)
|
||||
elif os.path.isfile(new_embeddings_bin_file_path):
|
||||
embeddings = torch.load(new_embeddings_bin_file_path)
|
||||
|
||||
with open(lora_config_path) as f:
|
||||
config = json.load(f)
|
||||
rank = config["r"]
|
||||
lora_alpha = config["lora_alpha"]
|
||||
return cls.from_lora_tensors(
|
||||
lora_model_id=get_lora_id()
|
||||
if lora_model_id is None else lora_model_id,
|
||||
rank=rank,
|
||||
lora_alpha=lora_alpha,
|
||||
tensors=tensors,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
embeddings=embeddings,
|
||||
target_embedding_padding=target_embedding_padding,
|
||||
embedding_modules=embedding_modules,
|
||||
embedding_padding_modules=embedding_padding_modules,
|
||||
)
|
||||
|
||||
|
||||
class LoRAModelManager:
|
||||
"""A manager that manages multiple LoRA-fine-tuned models."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
):
|
||||
"""Create a LoRAModelManager and adapter for a given model.
|
||||
|
||||
Args:
|
||||
model: the model to be adapted.
|
||||
max_num_seqs: the maximum number of sequences model can run in a
|
||||
single batch.
|
||||
max_num_batched_tokens: the maximum number of tokens model can run
|
||||
in a single batch.
|
||||
vocab_size: the vocab size of the model.
|
||||
lora_config: the LoRA configuration.
|
||||
"""
|
||||
self.lora_config = lora_config
|
||||
self.max_num_seqs = max_num_seqs
|
||||
assert self.capacity >= self.lora_slots
|
||||
self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8
|
||||
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
|
||||
self.vocab_size = vocab_size
|
||||
self.base_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.embeddings_indices = torch.empty(2,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.offsets = []
|
||||
# 4 is the number of indices tensors defined above
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices
|
||||
self.indices_len = [None] * 4
|
||||
|
||||
self.model: nn.Module = model
|
||||
if hasattr(self.model, "supported_lora_modules"):
|
||||
self.supported_lora_modules = copy.deepcopy(
|
||||
self.model.supported_lora_modules)
|
||||
self.packed_modules_mapping = copy.deepcopy(
|
||||
self.model.packed_modules_mapping)
|
||||
self.packed_modules: Dict[str, List[str]] = {}
|
||||
self.modules: Dict[str, "BaseLayerWithLoRA"] = {}
|
||||
self._registered_loras: Dict[int, LoRAModel] = {}
|
||||
# Dict instead of a Set for compatibility with LRUCache.
|
||||
self._active_loras: Dict[int, None] = {}
|
||||
self._last_mapping = None
|
||||
self._create_lora_modules()
|
||||
self.model.lora_manager = self
|
||||
|
||||
@property
|
||||
def capacity(self) -> int:
|
||||
return self.lora_config.max_cpu_loras
|
||||
|
||||
@property
|
||||
def lora_slots(self) -> int:
|
||||
return self.lora_config.max_loras
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._registered_loras)
|
||||
|
||||
def activate_lora(
|
||||
self,
|
||||
lora_id: int,
|
||||
) -> bool:
|
||||
"""Move LoRA into a GPU buffer to be used in the forward pass."""
|
||||
if lora_id in self._active_loras:
|
||||
return False
|
||||
first_free_slot = next(
|
||||
((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id)
|
||||
if lora_id is None), None)
|
||||
if first_free_slot is None:
|
||||
raise ValueError("No free lora slots")
|
||||
index, _ = first_free_slot
|
||||
self._active_loras[lora_id] = None
|
||||
lora_model = self._registered_loras[lora_id]
|
||||
logger.debug(
|
||||
f"Activating LoRA. int id: {lora_model.id}, slot index: {index}")
|
||||
self.lora_index_to_id[index] = lora_model.id
|
||||
for module_name, module in self.modules.items():
|
||||
module_lora = lora_model.get_lora(module_name)
|
||||
if module_lora:
|
||||
module_lora.optimize()
|
||||
module.set_lora(index, module_lora.lora_a, module_lora.lora_b,
|
||||
module_lora.embeddings_tensor)
|
||||
else:
|
||||
module.reset_lora(index)
|
||||
return True
|
||||
|
||||
def _deactivate_lora(self, lora_id: int):
|
||||
try:
|
||||
index = self.lora_index_to_id.index(lora_id)
|
||||
self.lora_index_to_id[index] = None
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
def deactivate_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRA from a GPU buffer."""
|
||||
if lora_id in self._active_loras:
|
||||
self._deactivate_lora(lora_id)
|
||||
self._active_loras.pop(lora_id)
|
||||
return True
|
||||
return False
|
||||
|
||||
def _add_lora(self, lora: LoRAModel) -> None:
|
||||
self._create_merged_loras_inplace(lora)
|
||||
self._registered_loras[lora.id] = lora
|
||||
|
||||
def add_lora(self, lora: LoRAModel) -> bool:
|
||||
"""Add a LoRAModel to the manager CPU cache."""
|
||||
if lora.id not in self._registered_loras:
|
||||
if len(self._registered_loras) >= self.capacity:
|
||||
raise RuntimeError("No free LoRA slots.")
|
||||
self._add_lora(lora)
|
||||
return True
|
||||
return False
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
"""Remove a LoRAModel from the manager CPU cache."""
|
||||
# TODO: should we check active lora?
|
||||
self.deactivate_lora(lora_id)
|
||||
return bool(self._registered_loras.pop(lora_id, None))
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def _set_lora_mapping(self, mapping: LoRAMapping) -> None:
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices,
|
||||
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
|
||||
self.lora_slots + 1, self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size)
|
||||
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
|
||||
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
|
||||
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
|
||||
sampler_indices_padded)
|
||||
self.embeddings_indices[:embeddings_indices.
|
||||
shape[0], :embeddings_indices.shape[1]].copy_(
|
||||
embeddings_indices)
|
||||
# Maintain the reference
|
||||
self.indices_len[:] = indices_len
|
||||
|
||||
def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None:
|
||||
if self._last_mapping != lora_mapping:
|
||||
self._set_lora_mapping(lora_mapping)
|
||||
self._last_mapping = lora_mapping
|
||||
|
||||
def list_loras(self) -> Dict[int, LoRAModel]:
|
||||
"""List all registered LoRAModels."""
|
||||
return dict(self._registered_loras)
|
||||
|
||||
def get_lora(self, lora_id: int) -> Optional[LoRAModel]:
|
||||
return self._registered_loras.get(lora_id, None)
|
||||
|
||||
def remove_all_loras(self) -> None:
|
||||
"""Remove all LoRAModels from the manager."""
|
||||
self._registered_loras.clear()
|
||||
self.lora_index_to_id = [None] * self.lora_slots
|
||||
self._active_loras.clear()
|
||||
|
||||
def _create_lora_modules(self):
|
||||
for module_name, module in self.model.named_modules():
|
||||
if not self._match_target_modules(module_name):
|
||||
continue
|
||||
|
||||
new_module = replace_submodule(
|
||||
self.model, module_name,
|
||||
from_layer(module, self.lora_slots, self.lora_config,
|
||||
self.model.config))
|
||||
# (yard1): TODO make this more robust
|
||||
if "lm_head" in module_name:
|
||||
sampler_module = self.model.get_submodule("sampler")
|
||||
new_module = replace_submodule(
|
||||
self.model, "sampler",
|
||||
from_layer_sampler(sampler_module, module, self.lora_slots,
|
||||
self.lora_config, self.model.config))
|
||||
self.register_module(module_name, new_module)
|
||||
self._register_packed_modules(module_name)
|
||||
new_module.set_mapping(self.base_indices, self.sampler_indices,
|
||||
self.sampler_indices_padded,
|
||||
self.embeddings_indices, self.indices_len)
|
||||
|
||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||
assert isinstance(module, BaseLayerWithLoRA)
|
||||
self.modules[module_name] = module
|
||||
|
||||
def create_dummy_lora(
|
||||
self,
|
||||
lora_id: int,
|
||||
rank: int,
|
||||
embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel:
|
||||
"""Create zero-initialized LoRAModel for warmup."""
|
||||
model = LoRAModel(lora_id, rank, {})
|
||||
for module_name, module in self.model.named_modules():
|
||||
if not self._match_target_modules(module_name) or not isinstance(
|
||||
module, BaseLayerWithLoRA):
|
||||
continue
|
||||
parts = module_name.split(".")
|
||||
if module_name not in self.packed_modules:
|
||||
if parts[-1] in embedding_modules:
|
||||
input_dim = (module.base_layer.org_vocab_size +
|
||||
self.lora_config.lora_extra_vocab_size if
|
||||
hasattr(module.base_layer, "org_vocab_size")
|
||||
else module.base_layer.weight.shape[1])
|
||||
output_dim = module.base_layer.embedding_dim if hasattr(
|
||||
module.base_layer,
|
||||
"embedding_dim") else module.base_layer.weight.shape[0]
|
||||
embeddings_tensor_dim = (module.base_layer.embedding_dim if
|
||||
hasattr(module.base_layer,
|
||||
"embedding_dim") else
|
||||
module.base_layer.weight.shape[1])
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
module_name,
|
||||
input_dim,
|
||||
output_dim,
|
||||
rank,
|
||||
module.lora_a_stacked.dtype,
|
||||
"cpu",
|
||||
embeddings_tensor_dim=embeddings_tensor_dim)
|
||||
else:
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
module_name,
|
||||
module.lora_a_stacked.shape[-1],
|
||||
module.lora_b_stacked.shape[-2],
|
||||
rank,
|
||||
module.lora_a_stacked.dtype,
|
||||
"cpu",
|
||||
)
|
||||
lora.optimize()
|
||||
else:
|
||||
parts = module_name.split(".")
|
||||
replacements = self.packed_modules_mapping[parts[-1]]
|
||||
subloras = []
|
||||
for i, r in enumerate(replacements):
|
||||
lora = LoRALayerWeights.create_dummy_lora_weights(
|
||||
module_name + "." + r,
|
||||
module.lora_a_stacked[i].shape[-1],
|
||||
module.lora_b_stacked[i].shape[-2],
|
||||
rank,
|
||||
module.lora_a_stacked[i].dtype,
|
||||
"cpu",
|
||||
)
|
||||
lora.optimize()
|
||||
subloras.append(lora)
|
||||
lora = PackedLoRALayerWeights.pack(subloras)
|
||||
model.loras[module_name] = lora
|
||||
return model
|
||||
|
||||
def _match_target_modules(self, module_name: str):
|
||||
return any(
|
||||
re.match(
|
||||
r".*\.{target_module}$".format(target_module=target_module),
|
||||
module_name) or target_module == module_name
|
||||
for target_module in self.supported_lora_modules)
|
||||
|
||||
def _register_packed_modules(self, module_full_name: str) -> None:
|
||||
parts = module_full_name.split(".")
|
||||
module_name = parts[-1]
|
||||
replacements = self.packed_modules_mapping.get(module_name)
|
||||
if not replacements:
|
||||
return
|
||||
prefix = ".".join(parts[:-1])
|
||||
self.packed_modules[module_full_name] = [
|
||||
prefix + "." + r if prefix else r for r in replacements
|
||||
]
|
||||
|
||||
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
|
||||
for module_name, new_module_names in self.packed_modules.items():
|
||||
replacement_loras = []
|
||||
has_replacement = False
|
||||
for r in new_module_names:
|
||||
lora = lora_model.get_lora(r)
|
||||
replacement_loras.append(lora)
|
||||
if lora:
|
||||
has_replacement = True
|
||||
if not has_replacement:
|
||||
continue
|
||||
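            # Normalize falsy entries to an explicit None so that
            # PackedLoRALayerWeights.pack sees None for missing sub-LoRAs.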
for i in range(len(replacement_loras)):
|
||||
if replacement_loras[i]:
|
||||
continue
|
||||
replacement_loras[i] = None
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras)
|
||||
|
||||
|
||||
class LoRALRUCache(LRUCache):
|
||||
|
||||
def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable],
|
||||
None]):
|
||||
super().__init__(capacity)
|
||||
self.deactivate_lora_fn = deactivate_lora_fn
|
||||
|
||||
def _on_remove(self, key: Hashable, value: Any):
|
||||
logger.debug(f"Removing LoRA. int id: {key}")
|
||||
self.deactivate_lora_fn(key)
|
||||
return super()._on_remove(key, value)
|
||||
|
||||
|
||||
class LRUCacheLoRAModelManager(LoRAModelManager):
|
||||
"""A model manager that manages multiple LoRAs with LRU cache."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
):
|
||||
super().__init__(model, max_num_seqs, max_num_batched_tokens,
|
||||
vocab_size, lora_config)
|
||||
self._registered_loras: LoRALRUCache = LoRALRUCache(
|
||||
self.capacity, self.deactivate_lora)
|
||||
self._active_loras: LoRALRUCache = LoRALRUCache(
|
||||
self.lora_slots, self._deactivate_lora)
|
||||
|
||||
def list_loras(self) -> Dict[int, LoRAModel]:
|
||||
"""List all registered LoRAModels."""
|
||||
return dict(self._registered_loras.cache)
|
||||
|
||||
def add_lora(self, lora: LoRAModel) -> bool:
|
||||
"""Add a LoRAModel to the manager."""
|
||||
if lora.id not in self._registered_loras:
|
||||
self._add_lora(lora)
|
||||
was_added = True
|
||||
else:
|
||||
# We always touch to update the LRU cache order
|
||||
self._registered_loras.touch(lora.id)
|
||||
was_added = False
|
||||
return was_added
|
||||
|
||||
def activate_lora(
|
||||
self,
|
||||
lora_id: int,
|
||||
) -> bool:
|
||||
if lora_id not in self._active_loras and len(
|
||||
self._active_loras) >= self.lora_slots:
|
||||
self._active_loras.remove_oldest()
|
||||
result = super().activate_lora(lora_id)
|
||||
# We always touch to update the LRU cache order
|
||||
self._active_loras.touch(lora_id)
|
||||
return result
|
||||
|
||||
def remove_oldest_lora(self) -> bool:
|
||||
if len(self._registered_loras) > 0:
|
||||
self._registered_loras.remove_oldest()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def create_lora_manager(
|
||||
model: nn.Module,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager,
|
||||
**kwargs) -> LoRAModelManager:
|
||||
"""Create a LoRA adapter for a given model."""
|
||||
if not hasattr(model, "supported_lora_modules"):
|
||||
raise ValueError(f"Model {type(model)} is not supported for LoRA.")
|
||||
lora_manager = lora_manager_cls(
|
||||
model=model,
|
||||
max_num_seqs=max_num_seqs,
|
||||
max_num_batched_tokens=max_num_batched_tokens,
|
||||
vocab_size=vocab_size,
|
||||
lora_config=lora_config,
|
||||
**kwargs)
|
||||
return lora_manager
|
||||
170
vllm/lora/punica.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Based on code from https://github.com/punica-ai/punica
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def _raise_import_error(e):
|
||||
if torch.cuda.get_device_capability() < (8, 0):
|
||||
raise ImportError(
|
||||
"punica LoRA kernels require compute capability >= 8.0") from e
|
||||
else:
|
||||
raise ImportError(
|
||||
"punica LoRA kernels could not be imported. If you built vLLM "
|
||||
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
||||
"was set.") from e
|
||||
|
||||
|
||||
def bgmv(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
):
|
||||
"""
|
||||
Semantics:
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
|
||||
x: Shape: `[B, H1]`. Input vectors.
|
||||
w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
|
||||
matrices.
|
||||
indicies: Shape: `[B]`. Indices of the weight matrices.
|
||||
layer_idx: Layer index of the weight matrices.
|
||||
scale: Scaling factor.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
|
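# Reference sketch (hypothetical helper, not part of the kernel API): the
# docstring semantics above, written as a plain PyTorch loop. Handy for
# checking the fused kernel against small inputs.
def _bgmv_ref(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
              indicies: torch.LongTensor, layer_idx: int,
              scale: float) -> None:
    for i in range(x.size(0)):
        w = w_t_all[indicies[i], layer_idx]  # [H2, H1]
        y[i] += scale * (x[i].unsqueeze(0) @ w.transpose(-1, -2)).squeeze(0)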
||||
|
||||
|
||||
def add_lora(y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_t_all: torch.Tensor,
|
||||
wb_t_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
*,
|
||||
buffer: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Semantics:
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
|
||||
x: Shape: `[B, H1]`. Input vectors.
|
||||
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
|
||||
LoRA A matrices.
|
||||
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
|
||||
LoRA B matrices.
|
||||
indicies: Shape: `[B]`. Indices of the LoRA weights.
|
||||
layer_idx: Layer index of LoRA weights.
|
||||
scale: Scaling factor.
|
||||
buffer: Optional. Shape: `[B, R]`. Temporary buffer.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
r = wb_t_all.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default to avoid
|
||||
# numerical inaccuracies that would otherwise happen
|
||||
# due to downcasting.
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
|
||||
punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx,
|
||||
scale)
|
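# Reference sketch of the two dispatch_bgmv calls above (hypothetical helper):
# shrink with LoRA A into a float32 buffer, then expand with LoRA B and scale.
def _add_lora_ref(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor,
                  wb_t_all: torch.Tensor, indicies: torch.LongTensor,
                  layer_idx: int, scale: float) -> None:
    for i in range(x.size(0)):
        wa = wa_t_all[indicies[i], layer_idx]  # [R, H1]
        wb = wb_t_all[indicies[i], layer_idx]  # [H2, R]
        tmp = x[i].float() @ wa.float().t()  # [R], float32 accumulation
        y[i] += (scale * (tmp @ wb.float().t())).to(y.dtype)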
||||
|
||||
|
||||
def add_lora_slice(y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_t_all: torch.Tensor,
|
||||
wb_t_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
y_offset: int,
|
||||
y_slice_size: int,
|
||||
*,
|
||||
buffer: Optional[torch.Tensor] = None):
|
||||
"""
|
||||
Same as `add_lora` but operates on a slice of y.
|
||||
Pass the whole y and specify y_offset and y_slice_size.
|
||||
|
||||
Semantics:
|
||||
y[i] += (
|
||||
x[i].unsqueeze(0)
|
||||
@ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
@ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
|
||||
* scale
|
||||
).squeeze(0)
|
||||
|
||||
Args:
|
||||
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
|
||||
x: Shape: `[B, H1]`. Input vectors.
|
||||
wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
|
||||
LoRA A matrices.
|
||||
wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
|
||||
LoRA B matrices.
|
||||
indicies: Shape: `[B]`. Indices of the LoRA weights.
|
||||
layer_idx: Layer index of LoRA weights.
|
||||
scale: Scaling factor.
|
||||
y_offset: Offset to apply to the starting column of y.
|
||||
y_slice_size: Size of the y column slice.
|
||||
"""
|
||||
try:
|
||||
import vllm._punica_C as punica_kernels
|
||||
except ImportError as e:
|
||||
_raise_import_error(e)
|
||||
|
||||
r = wb_t_all.size(-1)
|
||||
if buffer is None:
|
||||
# We set the buffer to be float32 by default to avoid
|
||||
# numerical inaccuracies that would otherwise happen
|
||||
# due to downcasting.
|
||||
buffer = torch.zeros((x.size(0), r),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
punica_kernels.dispatch_bgmv_low_level(
|
||||
buffer,
|
||||
x,
|
||||
wa_t_all,
|
||||
indicies,
|
||||
layer_idx,
|
||||
1.0,
|
||||
x.size(1),
|
||||
buffer.size(1),
|
||||
0,
|
||||
)
|
||||
punica_kernels.dispatch_bgmv_low_level(
|
||||
y,
|
||||
buffer,
|
||||
wb_t_all,
|
||||
indicies,
|
||||
layer_idx,
|
||||
scale,
|
||||
buffer.size(1),
|
||||
y_slice_size,
|
||||
y_offset,
|
||||
)
|
||||
32
vllm/lora/request.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRARequest:
|
||||
"""
|
||||
Request for a LoRA adapter.
|
||||
|
||||
Note that this class should only be used internally. For online
|
||||
serving, it is recommended to not allow users to use this class but
|
||||
instead provide another layer of abstraction to prevent users from
|
||||
accessing unauthorized LoRA adapters.
|
||||
|
||||
lora_int_id must be globally unique for a given adapter.
|
||||
This is currently not enforced in vLLM.
|
||||
"""
|
||||
|
||||
lora_name: str
|
||||
lora_int_id: int
|
||||
lora_local_path: str
|
||||
|
||||
def __post_init__(self):
|
||||
if self.lora_int_id < 1:
|
||||
raise ValueError(
|
||||
f"lora_int_id must be > 0, got {self.lora_int_id}")
|
||||
|
||||
def __eq__(self, value: object) -> bool:
|
||||
return isinstance(
|
||||
value, LoRARequest) and self.lora_int_id == value.lora_int_id
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return self.lora_int_id
|
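# Example usage (hypothetical adapter name and path): requests that refer to
# the same adapter should reuse the same positive, globally unique int id.
if __name__ == "__main__":
    _req = LoRARequest(lora_name="sql-adapter",
                       lora_int_id=1,
                       lora_local_path="/path/to/sql_adapter")
    assert hash(_req) == 1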
||||
39
vllm/lora/utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
from typing import Tuple
|
||||
|
||||
from torch import nn
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def replace_submodule(model: nn.Module, module_name: str,
|
||||
new_module: nn.Module) -> nn.Module:
|
||||
"""Replace a submodule in a model with a new module."""
|
||||
parent = model.get_submodule(".".join(module_name.split(".")[:-1]))
|
||||
target_name = module_name.split(".")[-1]
|
||||
setattr(parent, target_name, new_module)
|
||||
return new_module
|
||||
|
||||
|
||||
def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]:
|
||||
"""Parse the name of lora weights.
|
||||
|
||||
Args:
|
||||
name: the name of the fine-tuned LoRA, e.g.
|
||||
base_model.model.dense1.weight
|
||||
Returns:
|
||||
Tuple(module_name, is_lora_a):
|
||||
module_name: the name of the module, e.g. model.dense1,
|
||||
is_lora_a: whether the tensor is lora_a or lora_b.
|
||||
"""
|
||||
parts = name.split(".")
|
||||
assert parts[0] == "base_model"
|
||||
assert parts[1] == "model"
|
||||
if parts[-1] == "weight":
|
||||
assert parts[-2] == "lora_A" or parts[-2] == "lora_B"
|
||||
return ".".join(parts[2:-2]), parts[-2] == "lora_A"
|
||||
|
||||
if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B":
|
||||
return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A"
|
||||
|
||||
raise ValueError(f"{name} is unsupported format")
|
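# Illustrative check of the parsing rules above, using a typical PEFT-style
# tensor name (the module path is hypothetical):
if __name__ == "__main__":
    _name = "base_model.model.layers.0.self_attn.q_proj.lora_A.weight"
    assert parse_fine_tuned_lora_name(_name) == (
        "layers.0.self_attn.q_proj", True)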
||||
238
vllm/lora/worker_manager.py
Normal file
@@ -0,0 +1,238 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod, abstractproperty
|
||||
from typing import Any, Dict, List, Optional, Set, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.lora.models import (LoRAModel, LoRAModelManager,
|
||||
LRUCacheLoRAModelManager, create_lora_manager)
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.config import LoRAConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbstractWorkerLoRAManager(ABC):
|
||||
"""Abstract class for managing LoRA models on the worker side."""
|
||||
|
||||
def __init__(self, max_num_seqs: int, max_num_batched_tokens: int,
|
||||
vocab_size: int, lora_config: LoRAConfig,
|
||||
device: torch.device):
|
||||
self.max_num_seqs = max_num_seqs
|
||||
self.max_num_batched_tokens = max_num_batched_tokens
|
||||
self.vocab_size = vocab_size
|
||||
self.device = device
|
||||
self.lora_config = lora_config
|
||||
|
||||
@abstractproperty
|
||||
def is_enabled(self) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def create_lora_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def set_active_loras(self, lora_requests: List[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def remove_all_loras(self) -> None:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def list_loras(self) -> Set[int]:
|
||||
...
|
||||
|
||||
|
||||
class WorkerLoRAManager(AbstractWorkerLoRAManager):
|
||||
"""WorkerLoRAManager that manages LoRA models on the worker side.
|
||||
|
||||
On every request, the requested LoRAs will be loaded (unless they are already
|
||||
loaded), and every other LoRA will be unloaded."""
|
||||
|
||||
_lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_seqs: int,
|
||||
max_num_batched_tokens: int,
|
||||
vocab_size: int,
|
||||
lora_config: LoRAConfig,
|
||||
device: torch.device,
|
||||
embedding_modules: Dict[str, str],
|
||||
embedding_padding_modules: List[str],
|
||||
lora_model_cls: Type[LoRAModel] = LoRAModel,
|
||||
):
|
||||
self._lora_manager: Optional[LoRAModelManager] = None
|
||||
self._lora_model_cls = lora_model_cls
|
||||
self.embedding_modules = embedding_modules
|
||||
self.embedding_padding_modules = embedding_padding_modules
|
||||
super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size,
|
||||
lora_config, device)
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
return True
|
||||
|
||||
def create_lora_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
lora_manager = create_lora_manager(
|
||||
model,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
lora_manager_cls=self._lora_manager_cls,
|
||||
)
|
||||
self._lora_manager: LoRAModelManager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
def set_active_loras(self, lora_requests: List[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
self._apply_loras(lora_requests)
|
||||
self._lora_manager.set_lora_mapping(lora_mapping)
|
||||
|
||||
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
|
||||
loras_that_exist = self.list_loras()
|
||||
loras_map = {
|
||||
lora_request.lora_int_id: lora_request
|
||||
for lora_request in lora_requests if lora_request
|
||||
}
|
||||
if len(loras_map) > self._lora_manager.lora_slots:
|
||||
raise RuntimeError(
|
||||
f"Number of requested LoRAs ({len(loras_map)}) is greater "
|
||||
"than the number of GPU LoRA slots "
|
||||
f"({self._lora_manager.lora_slots}).")
|
||||
|
||||
new_loras = set(loras_map)
|
||||
loras_to_add = new_loras - loras_that_exist
|
||||
loras_to_remove = loras_that_exist - new_loras
|
||||
|
||||
for lora_id in loras_to_remove:
|
||||
self.remove_lora(lora_id)
|
||||
|
||||
for lora_id in loras_to_add:
|
||||
self.add_lora(loras_map[lora_id])
|
||||
|
||||
def _load_lora(self, lora_request: LoRARequest) -> LoRAModel:
|
||||
try:
|
||||
lora = self._lora_model_cls.from_local_checkpoint(
|
||||
lora_request.lora_local_path,
|
||||
lora_model_id=lora_request.lora_int_id,
|
||||
device="cpu",
|
||||
dtype=self.lora_config.lora_dtype,
|
||||
target_embedding_padding=self.vocab_size +
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
embedding_modules=self.embedding_modules,
|
||||
embedding_padding_modules=self.embedding_padding_modules,
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Loading lora {lora_request.lora_local_path} failed") from e
|
||||
if lora.rank > self.lora_config.max_lora_rank:
|
||||
raise ValueError(
|
||||
f"LoRA rank {lora.rank} is greater than max_lora_rank "
|
||||
f"{self.lora_config.max_lora_rank}.")
|
||||
if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size:
|
||||
raise ValueError(
|
||||
f"LoRA added vocab size {lora.extra_vocab_size} is greater than "
|
||||
f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}."
|
||||
)
|
||||
return lora
|
||||
|
||||
def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
return False
|
||||
return self._lora_manager.add_lora(
|
||||
self._lora_manager.create_dummy_lora(lora_request.lora_int_id,
|
||||
rank, self.embedding_modules))
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id in self.list_loras():
|
||||
return False
|
||||
lora = self._load_lora(lora_request)
|
||||
loaded = self._lora_manager.add_lora(lora)
|
||||
self._lora_manager.activate_lora(lora.id)
|
||||
return loaded
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self._lora_manager.remove_lora(lora_id)
|
||||
|
||||
def remove_all_loras(self) -> None:
|
||||
self._lora_manager.remove_all_loras()
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return set(self._lora_manager.list_loras())
|
||||
|
||||
|
||||
class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
|
||||
"""WorkerLoRAManager that manages LoRA models on the worker side.
|
||||
|
||||
Uses an LRU cache. On every request, the requested LoRAs will be loaded
|
||||
(unless they are already loaded) and least recently used LoRAs will
|
||||
be unloaded if the cache is above capacity."""
|
||||
|
||||
_lora_manager_cls: Type[
|
||||
LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager
|
||||
|
||||
def create_lora_manager(
|
||||
self,
|
||||
model: torch.nn.Module,
|
||||
) -> Any:
|
||||
lora_manager = create_lora_manager(
|
||||
model,
|
||||
lora_manager_cls=self._lora_manager_cls,
|
||||
max_num_seqs=self.max_num_seqs,
|
||||
vocab_size=self.vocab_size,
|
||||
lora_config=self.lora_config,
|
||||
max_num_batched_tokens=self.max_num_batched_tokens,
|
||||
)
|
||||
self._lora_manager: LRUCacheLoRAModelManager = lora_manager
|
||||
return lora_manager.model
|
||||
|
||||
def _apply_loras(self, lora_requests: List[LoRARequest]) -> None:
|
||||
loras_map = {
|
||||
lora_request.lora_int_id: lora_request
|
||||
for lora_request in lora_requests if lora_request
|
||||
}
|
||||
if len(loras_map) > self._lora_manager.lora_slots:
|
||||
raise RuntimeError(
|
||||
f"Number of requested LoRAs ({len(loras_map)}) is greater "
|
||||
"than the number of GPU LoRA slots "
|
||||
f"({self._lora_manager.lora_slots}).")
|
||||
for lora in loras_map.values():
|
||||
self.add_lora(lora)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if lora_request.lora_int_id not in self.list_loras():
|
||||
# Remove before we load the new lora to save memory
|
||||
if len(self._lora_manager) + 1 > self._lora_manager.capacity:
|
||||
self._lora_manager.remove_oldest_lora()
|
||||
lora = self._load_lora(lora_request)
|
||||
loaded = self._lora_manager.add_lora(lora)
|
||||
else:
|
||||
# If the lora is already loaded, just touch it to
|
||||
# update its position in the caches
|
||||
loaded = self._lora_manager.get_lora(lora_request.lora_int_id)
|
||||
self._lora_manager.activate_lora(lora_request.lora_int_id)
|
||||
return loaded
|
||||
10
vllm/model_executor/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.utils import set_random_seed, get_model
|
||||
|
||||
__all__ = [
|
||||
"InputMetadata",
|
||||
"get_model",
|
||||
"SamplingMetadata",
|
||||
"set_random_seed",
|
||||
]
|
||||
BIN
vllm/model_executor/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/__pycache__/guided_decoding.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/__pycache__/input_metadata.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/__pycache__/model_loader.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/__pycache__/utils.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/__pycache__/weight_utils.cpython-310.pyc
Normal file
Binary file not shown.
99
vllm/model_executor/guided_decoding.py
Normal file
@@ -0,0 +1,99 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from copy import copy
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from json import dumps as json_dumps
|
||||
from re import escape as regex_escape
|
||||
from typing import Union, Tuple
|
||||
from pydantic import BaseModel
|
||||
|
||||
from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest
|
||||
from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor
|
||||
|
||||
|
||||
class GuidedDecodingMode(Enum):
|
||||
JSON = "json"
|
||||
REGEX = "regex"
|
||||
CHOICE = "choice"
|
||||
|
||||
|
||||
global_thread_pool = None  # used for generating logits processor FSMs
|
||||
|
||||
|
||||
async def get_guided_decoding_logits_processor(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest],
|
||||
tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]:
|
||||
"""
|
||||
Given an OpenAI-compatible request, check for guided decoding parameters
|
||||
and get the necessary logits processor for the given guide.
|
||||
We cache logit processors by (guide, tokenizer), and on cache hit
|
||||
we make a shallow copy to reuse the same underlying FSM.
|
||||
"""
|
||||
global global_thread_pool
|
||||
guide, mode = _get_guide_and_mode(request)
|
||||
if not guide:
|
||||
return None
|
||||
|
||||
if global_thread_pool is None:
|
||||
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=2)
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
result = await loop.run_in_executor(global_thread_pool,
|
||||
_get_cached_logits_processor, guide,
|
||||
tokenizer, mode)
|
||||
|
||||
logits_processor = copy(result)
|
||||
# reset logits processor's internal state
|
||||
logits_processor.init_state()
|
||||
return logits_processor
|
||||
|
||||
|
||||
def _get_guide_and_mode(
|
||||
request: Union[CompletionRequest, ChatCompletionRequest]
|
||||
) -> Tuple[str, GuidedDecodingMode]:
|
||||
|
||||
if request.guided_json:
|
||||
if not isinstance(request.guided_json, (str, dict, BaseModel)):
|
||||
raise TypeError("JSON schema must be str, dict, or BaseModel")
|
||||
|
||||
json = request.guided_json
|
||||
if isinstance(json, dict):
|
||||
# turn dict into hashable string
|
||||
json = json_dumps(json, sort_keys=True)
|
||||
elif isinstance(json, BaseModel):
|
||||
# use pydantic signature so that different model classes
|
||||
# with the same fields will get hashed the same
|
||||
json = str(json.__signature__)
|
||||
return json, GuidedDecodingMode.JSON
|
||||
|
||||
elif request.guided_regex:
|
||||
if not isinstance(request.guided_regex, str):
|
||||
raise TypeError("Regex must be string")
|
||||
return request.guided_regex, GuidedDecodingMode.REGEX
|
||||
|
||||
elif request.guided_choice:
|
||||
if not isinstance(request.guided_choice, list):
|
||||
raise TypeError("Choices must be a list")
|
||||
|
||||
# choice just uses regex
|
||||
choices = [
|
||||
regex_escape(str(choice)) for choice in request.guided_choice
|
||||
]
|
||||
choices_regex = "(" + "|".join(choices) + ")"
|
||||
return choices_regex, GuidedDecodingMode.CHOICE
|
||||
|
||||
else:
|
||||
return None, None
|
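# Sketch of the choice-to-regex conversion above: each choice is escaped and
# the alternatives grouped, so guided_choice=["yes", "no"] compiles to the
# regex "(yes|no)".
if __name__ == "__main__":
    _choices = [regex_escape(str(c)) for c in ["yes", "no", "maybe"]]
    assert "(" + "|".join(_choices) + ")" == "(yes|no|maybe)"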
||||
|
||||
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_cached_logits_processor(guide: str, tokenizer,
|
||||
mode: GuidedDecodingMode):
|
||||
if mode == GuidedDecodingMode.JSON:
|
||||
return JSONLogitsProcessor(guide, tokenizer)
|
||||
elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE:
|
||||
return RegexLogitsProcessor(guide, tokenizer)
|
||||
else:
|
||||
raise ValueError(f"Unknown guided decoding mode {mode}")
|
||||
129
vllm/model_executor/guided_logits_processors.py
Normal file
@@ -0,0 +1,129 @@
# Copyright 2024- the Outlines developers
# This file is adapted from
# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
from collections import defaultdict
from typing import Union, DefaultDict, Dict, List, Optional

import torch
from pydantic import BaseModel
from outlines.fsm.fsm import RegexFSM
from outlines.fsm.json_schema import build_regex_from_schema


class RegexLogitsProcessor:

    def __init__(self, regex_string: str, tokenizer):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        tokenizer = self.adapt_tokenizer(tokenizer)
        fsm = RegexFSM(regex_string, tokenizer)
        self.fsm = fsm

    def init_state(self):
        """Initialize the FSM states."""
        self.fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) == 0:
            self.init_state()
        else:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self.fsm_state[seq_id] = self.fsm.next_state(
                self.fsm_state[last_seq_id], last_token)

        allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id])

        # Mask out every token the FSM does not allow from this state.
        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)

        return scores

    def adapt_tokenizer(self, tokenizer):
        """Adapt vLLM's tokenizer so it can be used to compile the FSM.

        The API of Outlines tokenizers is slightly different from that of
        `transformers`. In addition, we need to handle the missing spaces in
        Llama's tokenizer to be able to compile FSMs for this model.

        """
        tokenizer.vocabulary = tokenizer.get_vocab()
        tokenizer.special_tokens = set(tokenizer.all_special_tokens)

        def convert_token_to_string(token: str) -> str:
            from transformers.file_utils import SPIECE_UNDERLINE

            string = tokenizer.convert_tokens_to_string([token])

            # A hack to handle the missing spaces in HF's Llama tokenizers.
            if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
                return " " + string

            return string

        tokenizer.convert_token_to_string = convert_token_to_string

        return tokenizer


class JSONLogitsProcessor(RegexLogitsProcessor):

    def __init__(self,
                 schema: Union[str, Dict, BaseModel],
                 tokenizer,
                 whitespace_pattern: Optional[str] = None):
        """Compile the FSM that drives the JSON-guided generation.

        Parameters
        ----------
        schema
            A JSON schema that encodes the structure we want the model to generate
        tokenizer
            The model's tokenizer
        whitespace_pattern
            Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
            Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
        """
        if isinstance(schema, type(BaseModel)):
            # `schema` is a pydantic model class.
            schema_str = json.dumps(schema.model_json_schema())
        elif isinstance(schema, dict):
            schema_str = json.dumps(schema)
        elif isinstance(schema, str):
            schema_str = schema
        else:
            raise ValueError(
                f"Cannot parse schema {schema}. The schema must be either "
                "a Pydantic class, a dictionary or a string that contains "
                "the JSON Schema specification")
        regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
        super().__init__(regex_string, tokenizer)
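

# --- Editor's note: hypothetical usage demo added by the editor, not part of
# the original diff. Assumes pydantic v2 (`model_json_schema`, as used above)
# and a locally available HF tokenizer; "gpt2" is an arbitrary choice.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    class Answer(BaseModel):
        city: str
        population: int

    tok = AutoTokenizer.from_pretrained("gpt2")
    processor = JSONLogitsProcessor(Answer, tok)
    processor.init_state()
    scores = torch.zeros(len(tok.get_vocab()))
    # After masking, only tokens that can start a JSON document matching
    # Answer's schema keep a finite logit.
    scores = processor([], scores)
    print((scores > -math.inf).sum().item(), "tokens allowed at step 0")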
54
vllm/model_executor/input_metadata.py
Normal file
@@ -0,0 +1,54 @@
from typing import Optional

import torch


class InputMetadata:
    """Metadata for input sequences. Used in PagedAttention.

    Args:
        is_prompt: Whether the batch is a prompt (prefill) run.
        slot_mapping: The address in the KV cache to write each token's new
            KV to.
        prompt_lens: Lengths of prompts.
        max_seq_len: The maximum sequence length in the batch.
        start_loc: The start location of each sequence in the batch.
        max_context_len: The maximum context length.
        context_lens: The length of attention context for each sequence.
        block_tables: The block tables (seq id -> list of physical blocks).
        use_cuda_graph: Whether the run is captured by a CUDA graph.
        kv_cache_dtype: Data type used to store the KV cache.
    """

    def __init__(
        self,
        is_prompt: bool,
        slot_mapping: torch.Tensor,
        prompt_lens: Optional[torch.Tensor],
        max_seq_len: Optional[int],
        start_loc: Optional[torch.Tensor],
        max_context_len: Optional[int],
        context_lens: Optional[torch.Tensor],
        block_tables: Optional[torch.Tensor],
        use_cuda_graph: bool,
        kv_cache_dtype: str,
    ) -> None:
        self.is_prompt = is_prompt
        self.prompt_lens = prompt_lens
        self.max_seq_len = max_seq_len
        self.start_loc = start_loc
        self.max_context_len = max_context_len
        self.slot_mapping = slot_mapping
        self.context_lens = context_lens
        self.block_tables = block_tables
        self.use_cuda_graph = use_cuda_graph
        self.kv_cache_dtype = kv_cache_dtype

        # Set during the execution of the first attention op.
        # FIXME(woosuk): This is a hack.
        self.attn_bias = None

    def __repr__(self) -> str:
        return ("InputMetadata("
                f"is_prompt={self.is_prompt}, "
                f"max_context_len={self.max_context_len}, "
                f"slot_mapping={self.slot_mapping}, "
                f"context_lens={self.context_lens}, "
                f"block_tables={self.block_tables}, "
                f"use_cuda_graph={self.use_cuda_graph}, "
                f"kv_cache_dtype={self.kv_cache_dtype})")
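

# --- Editor's note: hypothetical construction for a decode step, added by
# the editor and not part of the original diff; shapes and values are
# illustrative only.
if __name__ == "__main__":
    meta = InputMetadata(
        is_prompt=False,
        slot_mapping=torch.tensor([1023, 2047]),  # one new KV slot per seq
        prompt_lens=None,
        max_seq_len=None,
        start_loc=None,
        max_context_len=512,
        context_lens=torch.tensor([500, 317]),
        block_tables=torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]),
        use_cuda_graph=False,
        kv_cache_dtype="auto",
    )
    print(meta)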
0
vllm/model_executor/layers/__init__.py
Normal file
BIN
vllm/model_executor/layers/__pycache__/__init__.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/attention.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/layernorm.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/linear.cpython-310.pyc
Normal file
Binary file not shown.
BIN
vllm/model_executor/layers/__pycache__/sampler.cpython-310.pyc
Normal file
Binary file not shown.
237
vllm/model_executor/layers/activation.py
Normal file
@@ -0,0 +1,237 @@
"""Custom activation functions."""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from vllm._C import ops
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.utils import divide
from vllm.model_executor.utils import set_weight_attrs


class SiluAndMul(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.silu_and_mul(out, x)
        return out
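

# --- Editor's note: hypothetical parity check added by the editor, not part
# of the original diff. Assumes an ixformer-capable CUDA device; the shape
# and dtype are illustrative.
def _demo_silu_and_mul_parity() -> None:
    layer = SiluAndMul()
    x = torch.randn(4, 2 * 128, dtype=torch.float16, device="cuda")
    # The fused custom op should match the PyTorch-native fallback.
    assert torch.allclose(layer(x), layer._forward(x), atol=1e-3)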


class GeluAndMul(nn.Module):
    """An activation function for GeGLU.

    The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2.

    Shapes:
        x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d)
        return: (batch_size, seq_len, d) or (num_tokens, d)
    """

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        d = x.shape[-1] // 2
        return F.gelu(x[..., :d]) * x[..., d:]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        output_shape = (x.shape[:-1] + (d, ))
        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
        ops.gelu_and_mul(out, x)
        return out


class NewGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        c = math.sqrt(2.0 / math.pi)
        return 0.5 * x * (1.0 + torch.tanh(c *
                                           (x + 0.044715 * torch.pow(x, 3.0))))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # NOTE: the ixformer shim in vllm/_C.py returns the result instead of
        # writing into `out`, so capture the return value.
        return ops.gelu_new(out, x)


class FastGELU(nn.Module):

    def _forward(self, x: torch.Tensor) -> torch.Tensor:
        """PyTorch-native implementation equivalent to forward()."""
        return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 *
                                           (1.0 + 0.044715 * x * x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        # NOTE: the ixformer shim in vllm/_C.py returns the result instead of
        # writing into `out`, so capture the return value.
        return ops.gelu_fast(out, x)


class ScaledActivation(nn.Module):
    """An activation function with post-scale parameters.

    This is used for some quantization methods like AWQ.
    """

    def __init__(
        self,
        act_module: nn.Module,
        intermediate_size: int,
        input_is_parallel: bool = True,
        params_dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.act = act_module
        self.input_is_parallel = input_is_parallel
        if input_is_parallel:
            tp_size = get_tensor_model_parallel_world_size()
            intermediate_size_per_partition = divide(intermediate_size,
                                                     tp_size)
        else:
            intermediate_size_per_partition = intermediate_size
        if params_dtype is None:
            params_dtype = torch.get_default_dtype()
        self.scales = nn.Parameter(
            torch.empty(intermediate_size_per_partition, dtype=params_dtype))
        set_weight_attrs(self.scales, {"weight_loader": self.weight_loader})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x) / self.scales

    def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
        param_data = param.data
        if self.input_is_parallel:
            tp_rank = get_tensor_model_parallel_rank()
            shard_size = param_data.shape[0]
            start_idx = tp_rank * shard_size
            loaded_weight = loaded_weight.narrow(0, start_idx, shard_size)
        assert param_data.shape == loaded_weight.shape
        param_data.copy_(loaded_weight)


_ACTIVATION_REGISTRY = {
    "gelu": nn.GELU(),
    "gelu_fast": FastGELU(),
    "gelu_new": NewGELU(),
    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
    "relu": nn.ReLU(),
}


def get_act_fn(
    act_fn_name: str,
    quant_config: Optional[QuantizationConfig] = None,
    intermediate_size: Optional[int] = None,
    input_is_parallel: bool = True,
    params_dtype: Optional[torch.dtype] = None,
) -> nn.Module:
    """Get an activation function by name."""
    act_fn_name = act_fn_name.lower()
    if act_fn_name not in _ACTIVATION_REGISTRY:
        raise ValueError(
            f"Activation function {act_fn_name!r} is not supported.")

    act_fn = _ACTIVATION_REGISTRY[act_fn_name]
    if (quant_config is not None
            and act_fn_name in quant_config.get_scaled_act_names()):
        if intermediate_size is None:
            raise ValueError("intermediate_size must be specified for scaled "
                             "activation functions.")
        return ScaledActivation(act_fn, intermediate_size, input_is_parallel,
                                params_dtype)
    return act_fn
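

# --- Editor's note: hypothetical lookup added by the editor, not part of the
# original diff. The name mirrors HF config values such as "gelu_new"; no
# quantization config is passed, so the bare activation module is returned.
def _demo_get_act_fn() -> None:
    act = get_act_fn("gelu_new")
    y = act(torch.randn(2, 8))
    assert y.shape == (2, 8)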


# ↓ add for smoothquant
class DequantSiluAndMulQuant(nn.Module):
    """An activation function for SwiGLU.

    The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2.

    Shapes:
        x: (num_tokens, 2 * d)
        return: (num_tokens, d)
    """

    # TODO(Zhang Ying): use_per_token_quant
    def __init__(self,
                 gate_dequant_scale: float = 1.0,
                 up_dequant_scale: float = 1.0,
                 quant_scale: float = 1.0,
                 use_per_token_quant: bool = True) -> None:
        super().__init__()
        self.register_parameter(
            "gate_dequant_scale",
            torch.nn.Parameter(
                torch.tensor(gate_dequant_scale,
                             dtype=torch.float32,
                             requires_grad=False)))
        self.register_parameter(
            "up_dequant_scale",
            torch.nn.Parameter(
                torch.tensor(up_dequant_scale,
                             dtype=torch.float32,
                             requires_grad=False)))
        self.register_parameter(
            "quant_scale",
            torch.nn.Parameter(
                torch.tensor(quant_scale,
                             dtype=torch.float32,
                             requires_grad=False)))
        self.use_per_token_quant = use_per_token_quant

    def _apply(self, fn):
        super()._apply(fn)
        # Keep the scalar scales on the CPU even when the module is moved.
        self.gate_dequant_scale.data = self.gate_dequant_scale.cpu()
        self.up_dequant_scale.data = self.up_dequant_scale.cpu()
        self.quant_scale.data = self.quant_scale.cpu()
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        # Follow the requested device/dtype, then force the scales back to
        # float32.
        self.gate_dequant_scale.data = self.gate_dequant_scale.to(
            *args, **kwargs)
        self.gate_dequant_scale.data = self.gate_dequant_scale.to(
            torch.float32)
        self.up_dequant_scale.data = self.up_dequant_scale.to(*args, **kwargs)
        self.up_dequant_scale.data = self.up_dequant_scale.to(torch.float32)
        self.quant_scale.data = self.quant_scale.to(*args, **kwargs)
        self.quant_scale.data = self.quant_scale.to(torch.float32)
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        num_tokens = x.numel() // x.shape[-1]
        d = x.shape[-1] // 2
        out = torch.empty(*x.shape[:-1], d, dtype=torch.int8, device=x.device)
        if self.use_per_token_quant:
            scale = torch.empty(num_tokens,
                                dtype=torch.float32,
                                device=x.device)
            # tmp is scratch space used inside the kernel.
            tmp = torch.empty(num_tokens,
                              d,
                              dtype=torch.float32,
                              device=x.device)
            ops.dequant_silu_and_mul_quant(out, x,
                                           self.gate_dequant_scale.item(),
                                           self.up_dequant_scale.item(),
                                           scale, tmp)
            return out, scale
        else:
            ops.dequant_silu_and_mul_quant(out, x,
                                           self.gate_dequant_scale.item(),
                                           self.up_dequant_scale.item(),
                                           self.quant_scale.item())
            return out
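

# --- Editor's note: hypothetical call shape added by the editor, not part of
# the original diff. The fused kernel (ops.dequant_silu_and_mul_quant) needs
# the ixformer smoothquant build, so this is a sketch rather than a test; the
# int32 input dtype is an assumption about the upstream int8 GEMM output.
def _demo_dequant_silu_and_mul_quant() -> None:
    layer = DequantSiluAndMulQuant(gate_dequant_scale=0.02,
                                   up_dequant_scale=0.03)
    x = torch.randint(-2 ** 20, 2 ** 20, (4, 2 * 128),
                      dtype=torch.int32, device="cuda")
    out_int8, per_token_scale = layer(x)  # per-token quantization path
    assert out_int8.dtype == torch.int8 and per_token_scale.shape == (4, )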
542
vllm/model_executor/layers/attention.py
Normal file
@@ -0,0 +1,542 @@
"""Multi-head attention."""
import os

enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN", None)

from typing import List, Optional

import importlib
import torch
import torch.nn as nn
from ixformer.contrib.xformers import ops as xops
from ixformer.contrib.xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalMask, LowerTriangularMaskWithTensorBias)

from vllm._C import ops
from vllm._C import cache_ops
from vllm.model_executor.input_metadata import InputMetadata
# NOTE: the prefix-enabled branch of PagedAttention.forward below calls
# context_attention_fwd; re-enable this import before using prefix caching.
## from vllm.model_executor.layers.triton_kernel.prefix_prefill import (
##     context_attention_fwd)
from vllm.utils import is_hip

# _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
# # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
# _PARTITION_SIZE = 512
_SUPPORTED_HEAD_SIZES = [64, 128, 256]
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 256


class PagedAttention(nn.Module):
    """MHA/MQA/GQA layer with PagedAttention.

    This class takes query, key, and value tensors as input. The input tensors
    can either contain prompt tokens or generation tokens.
    The class does the following:

    1. Reshape and store the input key and value tensors in the KV cache.
    2. Perform (multi-head/multi-query/grouped-query) attention using either
       xformers or the PagedAttention custom op.
    3. Return the output tensor.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        self.sliding_window = sliding_window
        if alibi_slopes is not None:
            alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

        self.use_ref_attention = self.check_use_ref_attention()

        # TODO align: upstream vLLM does not need these.
        self.attn_op = xops.fmha.flash.FwOp()
        head_mapping = torch.repeat_interleave(
            torch.arange(self.num_kv_heads, dtype=torch.int32),
            self.num_queries_per_kv)
        self.register_buffer("head_mapping", head_mapping, persistent=False)

    def check_use_ref_attention(self) -> bool:
        if not is_hip():
            return False
        # For ROCm, check whether flash attention is installed or not.
        # If not, use_ref_attention needs to be True.
        return importlib.util.find_spec("flash_attn") is None

    def ref_masked_attention(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        seq_len, _, _ = query.shape
        attn_mask = torch.triu(torch.ones(seq_len,
                                          seq_len,
                                          dtype=query.dtype,
                                          device=query.device),
                               diagonal=1)
        attn_mask = attn_mask * torch.finfo(query.dtype).min

        attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query,
                                                 key).float()
        attn_weights = attn_weights + attn_mask.float()
        attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype)
        out = torch.einsum("hqk,khd->qhd", attn_weights, value)
        return out

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """PagedAttention forward pass.

        Args:
            query: shape = [num_tokens, num_heads * head_size]
            key: shape = [num_tokens, num_kv_heads * head_size]
            value: shape = [num_tokens, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [num_tokens, num_heads * head_size]
        """
        num_tokens, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)
        slot_mapping = input_metadata.slot_mapping

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                slot_mapping,
            )

        if input_metadata.is_prompt:
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            input_metadata.prompt_lens)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            input_metadata.prompt_lens)
                        input_metadata.attn_bias = attn_bias

                if self.use_ref_attention:
                    output = self.ref_masked_attention(
                        query,
                        key,
                        value,
                    )
                    # Using view raises "RuntimeError: view size is not
                    # compatible with input tensor's size and stride (at least
                    # one dimension spans across two contiguous subspaces)",
                    # so use reshape instead.
                    return output.reshape(num_tokens, hidden_size)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                query = query.unsqueeze(0)
                key = key.unsqueeze(0)
                value = value.unsqueeze(0)

                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=self.attn_op,
                    alibi_slopes=self.alibi_slopes)
                output = out.view_as(query)
            else:
                # prefix-enabled attention (requires context_attention_fwd;
                # see the commented-out import at the top of this file).
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )
        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.head_mapping,  # upstream vLLM passes self.num_kv_heads
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(num_tokens, hidden_size)
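

# --- Editor's note: hypothetical profiling-run call added by the editor, not
# part of the original diff. Passing key_cache=value_cache=None skips the
# KV-cache write, as in the initial memory-profiling run; `meta` stands for
# an InputMetadata built as in vllm/model_executor/input_metadata.py.
def _demo_paged_attention_prompt(meta: InputMetadata) -> torch.Tensor:
    attn = PagedAttention(num_heads=8, head_size=128, scale=128 ** -0.5)
    q = torch.randn(16, 8 * 128, dtype=torch.float16, device="cuda")
    k = torch.randn(16, 8 * 128, dtype=torch.float16, device="cuda")
    v = torch.randn(16, 8 * 128, dtype=torch.float16, device="cuda")
    return attn(q, k, v, None, None, meta)  # -> [16, 8 * 128]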


# TODO align
"""
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        PagedAttention forward pass.

        Args:
            query: shape = [batch_size, seq_len, num_heads * head_size]
            key: shape = [batch_size, seq_len, num_kv_heads * head_size]
            value: shape = [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                block_size, x]
            value_cache: shape = [num_blocks, num_kv_heads, head_size,
                block_size]
            input_metadata: metadata for the inputs.
        Returns:
            shape = [batch_size, seq_len, num_heads * head_size]

        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )

        if input_metadata.is_prompt:
            # normal attention
            if (key_cache is None or value_cache is None
                    or input_metadata.block_tables.numel() == 0):
                if self.num_kv_heads != self.num_heads:
                    # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
                    # project the key and value tensors to the desired number
                    # of heads.
                    # TODO(woosuk): Use MQA/GQA kernels for higher performance.
                    query = query.view(query.shape[0], self.num_kv_heads,
                                       self.num_queries_per_kv,
                                       query.shape[-1])
                    key = key[:, :,
                              None, :].expand(key.shape[0], self.num_kv_heads,
                                              self.num_queries_per_kv,
                                              key.shape[-1])
                    value = value[:, :,
                                  None, :].expand(value.shape[0],
                                                  self.num_kv_heads,
                                                  self.num_queries_per_kv,
                                                  value.shape[-1])

                # Set attention bias if not provided. This typically happens
                # at the very first attention layer of every iteration.
                # FIXME(woosuk): This is a hack.
                if input_metadata.attn_bias is None:
                    if self.alibi_slopes is None:
                        attn_bias = BlockDiagonalCausalMask.from_seqlens(
                            [seq_len] * batch_size)
                        if self.sliding_window is not None:
                            attn_bias = attn_bias.make_local_attention(
                                self.sliding_window)
                        input_metadata.attn_bias = attn_bias
                    else:
                        input_metadata.attn_bias = _make_alibi_bias(
                            self.alibi_slopes, self.num_kv_heads, batch_size,
                            seq_len, query.dtype)

                if self.use_ref_attention:
                    output = self.ref_masked_attention(
                        query,
                        key,
                        value,
                    )
                    # Using view raises "RuntimeError: view size is not
                    # compatible with input tensor's size and stride (at least
                    # one dimension spans across two contiguous subspaces)",
                    # so use reshape instead.
                    return output.reshape(batch_size, seq_len, hidden_size)

                # TODO(woosuk): Too many view operations. Let's try to reduce
                # them in the future for code readability.
                if self.alibi_slopes is None:
                    query = query.unsqueeze(0)
                    key = key.unsqueeze(0)
                    value = value.unsqueeze(0)
                else:
                    query = query.unflatten(0, (batch_size, seq_len))
                    key = key.unflatten(0, (batch_size, seq_len))
                    value = value.unflatten(0, (batch_size, seq_len))

                out = xops.memory_efficient_attention_forward(
                    query,
                    key,
                    value,
                    attn_bias=input_metadata.attn_bias,
                    p=0.0,
                    scale=self.scale,
                    op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
                    (is_hip()) else None,
                )
                output = out.view_as(query)
            else:
                # prefix-enabled attention
                output = torch.empty_like(query)
                context_attention_fwd(
                    query,
                    key,
                    value,
                    output,
                    key_cache,
                    value_cache,
                    input_metadata.block_tables,  # [BS, max_block_per_request]
                    input_metadata.start_loc,
                    input_metadata.prompt_lens,
                    input_metadata.context_lens,
                    input_metadata.max_seq_len,
                    getattr(self, "alibi_slopes", None),
                )

        else:
            # Decoding run.
            output = _paged_attention(
                query,
                key_cache,
                value_cache,
                input_metadata,
                self.num_kv_heads,
                self.scale,
                self.alibi_slopes,
            )

        # Reshape the output tensor.
        return output.view(batch_size, seq_len, hidden_size)
"""


def _make_alibi_bias(
    alibi_slopes: torch.Tensor,
    num_kv_heads: int,
    batch_size: int,
    seq_len: int,
    dtype: torch.dtype,
) -> LowerTriangularMaskWithTensorBias:
    bias = torch.arange(seq_len, dtype=dtype)
    # NOTE(zhuohan): HF uses
    #     `bias = bias[None, :].repeat(prompt_len, 1)`
    # here. We find that both biases give the same results, but
    # the bias below more accurately follows the original ALiBi
    # paper.
    bias = bias[None, :] - bias[:, None]

    # When using custom attention bias, xformers requires the bias to
    # be sliced from a tensor whose length is a multiple of 8.
    padded_len = (seq_len + 7) // 8 * 8
    num_heads = alibi_slopes.shape[0]
    bias = torch.empty(
        batch_size,
        num_heads,
        seq_len,
        padded_len,
        device=alibi_slopes.device,
        dtype=dtype,
    )[:, :, :, :seq_len].copy_(bias)
    bias.mul_(alibi_slopes[:, None, None])
    if num_heads != num_kv_heads:
        bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
    attn_bias = LowerTriangularMaskWithTensorBias(bias)
    return attn_bias


def _paged_attention(
    query: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    input_metadata: InputMetadata,
    head_mapping: torch.Tensor,  # upstream vLLM passes num_kv_heads: int
    scale: float,
    alibi_slopes: Optional[torch.Tensor],
    use_sqrt_alibi: bool = False,
) -> torch.Tensor:
    output = torch.empty_like(query)

    use_v2 = enable_infer_paged_attn is None and key_cache.dim() == 4
    if not use_v2:
        block_size = value_cache.shape[3]
        # Run PagedAttention V1.
        ops.paged_attention_v1(
            output,
            query,
            key_cache,
            value_cache,
            head_mapping,  # upstream: num_kv_heads
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    else:
        # Run PagedAttention V2.
        block_size = value_cache.shape[2]
        num_seqs, num_heads, head_size = query.shape
        max_num_partitions = (
            (input_metadata.max_context_len + _PARTITION_SIZE - 1) //
            _PARTITION_SIZE)
        tmp_output = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions, head_size),
            dtype=output.dtype,
            device=output.device,
        )
        exp_sums = torch.empty(
            size=(num_seqs, num_heads, max_num_partitions),
            dtype=torch.float32,
            device=output.device,
        )
        max_logits = torch.empty_like(exp_sums)
        ops.paged_attention_v2(
            output,
            exp_sums,
            max_logits,
            tmp_output,
            query,
            key_cache,
            value_cache,
            head_mapping,  # upstream: num_kv_heads
            scale,
            input_metadata.block_tables,
            input_metadata.context_lens,
            block_size,
            input_metadata.max_context_len,
            alibi_slopes,
            input_metadata.kv_cache_dtype,
        )
    return output
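

# --- Editor's note: worked example of the V2 partitioning arithmetic above;
# added by the editor, pure Python, no GPU required.
def _demo_partition_count() -> None:
    max_context_len = 1000
    max_num_partitions = (max_context_len + _PARTITION_SIZE - 1) \
        // _PARTITION_SIZE
    # ceil(1000 / 256) == 4: tmp_output, exp_sums and max_logits each hold
    # one partial result per (sequence, head, partition).
    assert max_num_partitions == 4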


# ↓ add for smoothquant
class DequantPagedAttention(PagedAttention):

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        sliding_window: Optional[int] = None,
        quant_kv_cache: bool = False,
        kv_quant_params: torch.Tensor = None,
        quant_scale: float = 1.0,
        use_per_token_quant: bool = True,
    ) -> None:
        super().__init__(num_heads, head_size, scale, num_kv_heads,
                         alibi_slopes, sliding_window)
        self.register_parameter(
            "quant_scale",
            torch.nn.Parameter(
                torch.tensor(quant_scale,
                             dtype=torch.float32,
                             requires_grad=False)))
        self.use_per_token_quant = use_per_token_quant

    def _apply(self, fn):
        super()._apply(fn)
        # Keep the scalar quantization scale on the CPU.
        self.quant_scale.data = self.quant_scale.cpu()
        return self

    def to(self, *args, **kwargs):
        super().to(*args, **kwargs)
        self.quant_scale.data = self.quant_scale.to(*args, **kwargs)
        self.quant_scale.data = self.quant_scale.to(torch.float32)
        return self

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        out = super().forward(
            query,
            key,
            value,
            key_cache,
            value_cache,
            input_metadata,
        )
        # Quantize the attention output to int8 for the following
        # smoothquant GEMM.
        quant_out = torch.empty_like(out, dtype=torch.int8)
        if self.use_per_token_quant:
            scale = torch.empty(out.numel() // out.shape[-1],
                                dtype=torch.float32,
                                device=out.device)
            ops.quant(quant_out, out, scale)
            return quant_out, scale
        else:
            ops.quant(quant_out, out, self.quant_scale.item())
            return (quant_out, )
5
vllm/model_executor/layers/fused_moe/__init__.py
Normal file
@@ -0,0 +1,5 @@
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe

__all__ = [
    "fused_moe",
]