commit ae2c299b3a41c3af0df20733122cbc85d94dfb32 Author: wangjing Date: Thu Aug 7 07:25:16 2025 +0000 init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..800d003 --- /dev/null +++ b/.gitignore @@ -0,0 +1,205 @@ +# version file generated by setuptools-scm +/vllm/_version.py + +# vllm-flash-attn built from source +vllm/vllm_flash_attn/* + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +cmake-build-*/ +CMakeUserPresets.json +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +/.deps/ + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# generated files +**/generated/** + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site +docs/argparse +docs/examples + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. 
For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# VSCode +.vscode/ + +# DS Store +.DS_Store + +# Results +*.csv + +# Python pickle files +*.pkl + +# Sphinx documentation +_build/ + +# vim swap files +*.swo +*.swp + +# hip files generated by PyTorch +*.hip +*_hip* +hip_compat.h + +# Benchmark dataset +benchmarks/**/*.json + +# Linting +actionlint +shellcheck*/ + +# Ignore moe/marlin_moe gen code +csrc/moe/marlin_moe_wna16/kernel_* \ No newline at end of file diff --git a/vllm/_C.py b/vllm/_C.py new file mode 100644 index 0000000..0980701 --- /dev/null +++ b/vllm/_C.py @@ -0,0 +1,266 @@ +from typing import Optional + +import torch +import torch.nn.functional as F + +import ixformer +import ixformer.functions as ixf_F +from ixformer._C import ReduceOp +from ixformer._C import _distributed as cdist +from ixformer._C._distributed import is_initialized, get_default_comm_group +from ixformer.contrib.torch.extension import ixformer_torch as ixft +from ixformer.contrib.torch.data_type_mapping import torch_to_ixformer_dtype + + +class ops(): + # activations + @staticmethod + def silu_and_mul(output, x): + ixf_F.silu_and_mul(x, output) + + @staticmethod + def gelu_and_mul(output, x): + ixf_F.gelu_and_mul(x, output) + + @staticmethod + def gelu_new(output, x): + return F.gelu(x,"tanh") + + @staticmethod + def gelu_fast(output, x): + return F.gelu(x,"tanh") + + # rms norm + @staticmethod + def rms_norm(output, x, weight, epsilon): + ixf_F.rms_norm(x, weight, output, epsilon) + + @staticmethod + def fused_add_rms_norm(input, residual, weight, epsilon, scale): + ixf_F.fused_add_rms_norm(input, residual, weight, epsilon, scale) + + # rotary embedding + @staticmethod + def rotary_embedding(positions, query, key, head_size, + cos_sin_cache, is_neox_style): + ixf_F.vllm_rotary_embedding_neox(positions, query, key, head_size, + cos_sin_cache, is_neox_style) + + # paged attention + @staticmethod + def paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes=None, + kv_cache_dtype=None, + ): + return ixf_F.vllm_single_query_cached_kv_attention( + output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + ) + + @staticmethod + def paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes=None, + kv_cache_dtype=None, + use_sqrt_alibi=False, + ): + return ixf_F.vllm_single_query_cached_kv_attention_v2( + output, + 256, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + use_sqrt_alibi, + ) + + # awq + @staticmethod + def awq_gemm(x, qweight, scales, qzeros, pack_factor): + return ixf_F.quantized_linear(x,qweight,scales,"awq",32 // pack_factor,qzeros,None,group_size=128) + + @staticmethod + def awq_dequantize(qweight, scales, qzeros, holder1, holder2, holder3): + raise NotImplementedError() + + # gqt-q + @staticmethod + def gptq_shuffle(qweights,g_idx,weight_bits): + return ixf_F.vllm_gptq_shuffle(qweights,g_idx) + + @staticmethod + def gptq_gemm(x, qweight, qzeros, scales, idx, status, weight_bits): + batch = x.shape[0] + if batch <= 8: + return 
ixf_F.quantized_linear(x,qweight,scales,"gptq",4,qzeros,None,group_size=128) + o_dtype_str = "fp16" if x.dtype == torch.half else "bf16" + deq_w = ixf_F.quantized_weight_dequant(qweight,scales,"gptq",o_dtype_str,4,qzeros,group_size=128) + return torch.matmul(x,deq_w) + + # squeezellm + @staticmethod + def squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table): + raise NotImplementedError() + + # marlin + @staticmethod + def marlin_gemm(x_2d, qweight, scales, workspace, size_m, size_n, size_k): + raise NotImplementedError() + + # moe + @staticmethod + def moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad): + raise NotImplementedError() + + # smoothquant + @staticmethod + def quant(output,input,scale): + ixf_F.vllm_smooth_quant(output,input,scale) + return output + + @staticmethod + def dequant(output,x,scale,global_scale): + ixf_F.vllm_smooth_dequant(output,x,scale,global_scale) + return output + + @staticmethod + def dequant_add_residual(output,x,residual,scale,global_scale): + if isinstance(x,torch.Tensor): + ixf_F.vllm_smooth_dequant_add_residual(output,x,residual,scale,global_scale) + return output + + @staticmethod + def dequant_silu_and_mul_quant(output,x,gate_scale, up_scale, scale, temp = None): + ixf_F.vllm_smooth_dequant_silu_and_mul_quant(output,x,gate_scale, up_scale, scale, temp) + + @staticmethod + def rms_norm_quant(output, input, weight, epsilon): + return ixf_F.vllm_smooth_rms_norm_quant(output, input, weight, epsilon) + + @staticmethod + def fused_add_rms_norm_quant(output, input, residual, weight, epsilon): + ixf_F.vllm_smooth_fused_add_rms_norm_quant(output, input, residual, weight, epsilon) + + @staticmethod + def dequant_fused_add_rms_norm_quant(output, input, residual, weight, epsilon, scale, global_scale): + ixf_F.vllm_smooth_dequant_fused_add_rms_norm_quant(output, input, residual, weight, epsilon, scale, global_scale) + + @staticmethod + def dequant_rotary_embedding(positions, query, key, head_size, + cos_sin_cache, query_out, key_out, query_scale, key_scale, is_neox_style): + ixf_F.vllm_smooth_dequant_rotary_embedding_neox(positions, query, key, head_size, + cos_sin_cache, query_out, key_out, query_scale, key_scale, is_neox_style) + + @staticmethod + def linear_a8_w8_o32_(x, weight, output): + return ixf_F.linear_i8w8o32(x,weight,output) + + +class cache_ops(): + + @staticmethod + def reshape_and_cache(key, value, key_cache, value_cache, slot_mapping): + ixf_F.vllm_cache_ops_reshape_and_cache( + key, value, key_cache, value_cache, slot_mapping + ) + + @staticmethod + def copy_blocks(key_caches, value_caches, block_mapping): + ixf_F.vllm_copy_cache( + key_caches, value_caches, block_mapping + ) + + @staticmethod + def swap_blocks(src_key_cache, dst_key_cache, src_to_dst): + ixf_F.vllm_swap_blocks( + src_key_cache, dst_key_cache, src_to_dst + ) + +class custom_ar(): + + IS_INIT:bool = False + + @staticmethod + def is_init(): + return_status = custom_ar.IS_INIT + custom_ar.IS_INIT = True + return return_status + + @staticmethod + def init_cumtom_ar(): + if not is_initialized(get_default_comm_group()): + group = ixft.create_ixformer_group_from_pg() + ixformer.cuda.set_device(torch.cuda.current_device()) + cdist.update_default_comm_group(group) + cdist.ipc.init_communicator_by_nccl() + + @staticmethod + def all_reduce_reg(ptr,tensor,out = None): + raise NotImplementedError() + + @staticmethod + def all_reduce_unreg(ptr,tensor,buffer,out = None): + dtype = tensor.dtype + if torch.is_tensor(tensor): + dtype = 
torch_to_ixformer_dtype(dtype) + + if out is None: + out = tensor + cdist.ipc.allreduce( + tensor.data_ptr(), out.data_ptr(), dtype, tensor.numel(), ReduceOp.SUM + ) + return out + + @staticmethod + def dispose(): + ixformer.distributed.destroy_process_group() + + @staticmethod + def should_custom_ar(tensor:torch.Tensor, max_size, world_size, full_nvlink): + return cdist.ipc.should_custom_ar(tensor.numel(),tensor.element_size(),max_size,world_size) + +class cuda_utils(): + @staticmethod + def get_max_shared_memory_per_block_device_attribute(gpu): + return 100000000 diff --git a/vllm/__init__.py b/vllm/__init__.py new file mode 100644 index 0000000..96f1f26 --- /dev/null +++ b/vllm/__init__.py @@ -0,0 +1,29 @@ +"""vLLM: a high-throughput and memory-efficient inference engine for LLMs""" +import os + +# By default, to avoid memory fragmentation, disable UMD mempool +if os.getenv("UMD_ENABLEMEMPOOL") is None: + os.environ["UMD_ENABLEMEMPOOL"] = "0" +os.environ["NCCL_FORCESYNC_DISABLE"] = "1" + +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.ray_utils import initialize_cluster +from vllm.entrypoints.llm import LLM +from vllm.outputs import CompletionOutput, RequestOutput +from vllm.sampling_params import SamplingParams + +__version__ = "0.3.3" + +__all__ = [ + "LLM", + "SamplingParams", + "RequestOutput", + "CompletionOutput", + "LLMEngine", + "EngineArgs", + "AsyncLLMEngine", + "AsyncEngineArgs", + "initialize_cluster", +] diff --git a/vllm/_moe_C.py b/vllm/_moe_C.py new file mode 100644 index 0000000..f974826 --- /dev/null +++ b/vllm/_moe_C.py @@ -0,0 +1,5 @@ +import torch +import ixformer.functions as ixf_F + +def topk_softmax(topk_weights,topk_ids,token_expert_indicies,gating_output): + raise NotImplementedError() diff --git a/vllm/block.py b/vllm/block.py new file mode 100644 index 0000000..5fe39ed --- /dev/null +++ b/vllm/block.py @@ -0,0 +1,72 @@ +"""Token blocks.""" +from typing import List + +from vllm.utils import Device + +_BLANK_TOKEN_ID = -1 + + +class LogicalTokenBlock: + """A block that stores a contiguous chunk of tokens from left to right. + + Logical blocks are used to represent the states of the corresponding + physical blocks in the KV cache. 
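+
+    Illustrative example of the slot bookkeeping:
+
+        block = LogicalTokenBlock(block_number=0, block_size=4)
+        block.append_tokens([1, 2, 3])
+        block.get_num_empty_slots()  # -> 1
+        block.is_full()              # -> False
+        block.get_token_ids()        # -> [1, 2, 3]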
+ """ + + def __init__( + self, + block_number: int, + block_size: int, + ) -> None: + self.block_number = block_number + self.block_size = block_size + + self.token_ids = [_BLANK_TOKEN_ID] * block_size + self.num_tokens = 0 + + def is_empty(self) -> bool: + return self.num_tokens == 0 + + def get_num_empty_slots(self) -> int: + return self.block_size - self.num_tokens + + def is_full(self) -> bool: + return self.num_tokens == self.block_size + + def append_tokens(self, token_ids: List[int]) -> None: + assert len(token_ids) <= self.get_num_empty_slots() + curr_idx = self.num_tokens + self.token_ids[curr_idx:curr_idx + len(token_ids)] = token_ids + self.num_tokens += len(token_ids) + + def get_token_ids(self) -> List[int]: + return self.token_ids[:self.num_tokens] + + def get_last_token_id(self) -> int: + assert self.num_tokens > 0 + return self.token_ids[self.num_tokens - 1] + + +class PhysicalTokenBlock: + """Represents the state of a block in the KV cache.""" + + def __init__( + self, + device: Device, + block_number: int, + block_size: int, + ) -> None: + self.device = device + self.block_number = block_number + self.block_size = block_size + + self.ref_count = 0 + + def __repr__(self) -> str: + return (f'PhysicalTokenBlock(device={self.device}, ' + f'block_number={self.block_number}, ' + f'ref_count={self.ref_count})') + + +# Mapping: logical block number -> physical block. +BlockTable = List[PhysicalTokenBlock] diff --git a/vllm/config.py b/vllm/config.py new file mode 100644 index 0000000..9b63f55 --- /dev/null +++ b/vllm/config.py @@ -0,0 +1,689 @@ +from typing import Optional, Union, ClassVar +from dataclasses import dataclass +import os +from packaging.version import Version + +import torch +from transformers import PretrainedConfig + +from vllm.logger import init_logger +from vllm.transformers_utils.config import get_config +from vllm.utils import get_cpu_memory, is_hip, is_neuron, get_nvcc_cuda_version + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. 
It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. + """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + download_dir: Optional[str], + load_format: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.download_dir = download_dir + self.load_format = load_format + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.enforce_eager = True + # TODO align + # Use graph cause a runtime error, for now, do not use cuda graph + self.max_context_len_to_capture = max_context_len_to_capture + + if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C + if not os.path.exists(model): + model_path = snapshot_download(model_id=model, + cache_dir=download_dir, + revision=revision) + else: + model_path = model + self.model = model_path + self.download_dir = model_path + self.tokenizer = model_path + + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision) + self.dtype = _get_and_verify_dtype(self.hf_config, dtype) + self.max_model_len = _get_and_verify_max_len(self.hf_config, + max_model_len) + self._verify_load_format() + self._verify_tokenizer_mode() + self._verify_quantization() + self._verify_cuda_graph() + + def _verify_load_format(self) -> None: + load_format = self.load_format.lower() + supported_load_format = [ + "auto", "pt", "safetensors", "npcache", "dummy" + ] + rocm_not_supported_load_format = [] + if load_format not in supported_load_format: + raise ValueError( + f"Unknown load format: {self.load_format}. Must be one of " + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in supported_load_format + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format \'{load_format}\' is not supported in ROCm. 
" + f"Supported load format are " + f"{rocm_supported_load_format}") + + # TODO: Remove this check once HF updates the pt weights of Mixtral. + architectures = getattr(self.hf_config, "architectures", []) + if "MixtralForCausalLM" in architectures and load_format == "pt": + raise ValueError( + "Currently, the 'pt' format is not supported for Mixtral. " + "Please use the 'safetensors' format instead. ") + self.load_format = load_format + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode + + def _verify_quantization(self) -> None: + supported_quantization = ["awq", "gptq", "squeezellm", "marlin", "smoothquant"] + rocm_not_supported_quantization = ["awq", "marlin"] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. + hf_quant_config = getattr(self.hf_config, "quantization_config", None) + if hf_quant_config is not None: + + hf_quant_method = str(hf_quant_config["quant_method"]).lower() + # If the GPTQ model is serialized in marlin format, use marlin. + if (hf_quant_method == "gptq" + and "is_marlin_format" in hf_quant_config + and hf_quant_config["is_marlin_format"]): + hf_quant_method = "marlin" + if self.quantization is None: + self.quantization = hf_quant_method + elif self.quantization != hf_quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({hf_quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization in rocm_not_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not supported " + f"in ROCm.") + if self.quantization != "marlin": + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. 
The speed can be slower than " + "non-quantized models.") + + def _verify_cuda_graph(self) -> None: + if self.max_context_len_to_capture is None: + self.max_context_len_to_capture = self.max_model_len + self.max_context_len_to_capture = min(self.max_context_len_to_capture, + self.max_model_len) + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_num_attention_heads = self.hf_config.num_attention_heads + tensor_parallel_size = parallel_config.tensor_parallel_size + if total_num_attention_heads % tensor_parallel_size != 0: + raise ValueError( + f"Total number of attention heads ({total_num_attention_heads})" + " must be divisible by tensor parallel size " + f"({tensor_parallel_size}).") + + total_num_hidden_layers = self.hf_config.num_hidden_layers + pipeline_parallel_size = parallel_config.pipeline_parallel_size + if total_num_hidden_layers % pipeline_parallel_size != 0: + raise ValueError( + f"Total number of hidden layers ({total_num_hidden_layers}) " + "must be divisible by pipeline parallel size " + f"({pipeline_parallel_size}).") + + def get_sliding_window(self) -> Optional[int]: + return getattr(self.hf_config, "sliding_window", None) + + def get_vocab_size(self) -> int: + return self.hf_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_config.hidden_size + + def get_head_size(self) -> int: + if hasattr(self.hf_config, "head_dim"): + return self.hf_config.head_dim + # FIXME(woosuk): This may not be true for all models. + return self.hf_config.hidden_size // self.hf_config.num_attention_heads + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. + return 1 + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) + + def get_num_layers(self, parallel_config: "ParallelConfig") -> int: + total_num_hidden_layers = self.hf_config.num_hidden_layers + return total_num_hidden_layers // parallel_config.pipeline_parallel_size + + +class CacheConfig: + """Configuration for the KV cache. 
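+
+    The swap space is specified per GPU in GiB and is stored internally in
+    bytes (swap_space * 2**30); num_gpu_blocks and num_cpu_blocks are left
+    as None until profiling fills them in.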
+ + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.cache_dtype = cache_dtype + self.sliding_window = sliding_window + self._verify_args() + self._verify_cache_dtype() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def metrics_info(self): + # convert cache_config to dict(key: str, value:str) for prometheus metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype == "fp8_e5m2": + nvcc_cuda_version = get_nvcc_cuda_version() + if nvcc_cuda_version and nvcc_cuda_version < Version("11.8"): + raise ValueError( + "FP8 is not supported when cuda version is lower than 11.8." + ) + device_name = torch.cuda.get_device_name() + if "AMD" in device_name: + raise NotImplementedError( + "FP8_E5M2 KV Cache on AMD GPU has not been supported yet.") + logger.info( + "Using fp8_e5m2 data type to store kv cache. It reduces " + "the GPU memory footprint and boosts the performance. " + "But it may cause slight accuracy drop. " + "Currently we only support fp8 without scaling factors and " + "make e5m2 as a default format.") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def verify_with_parallel_config( + self, + parallel_config: "ParallelConfig", + ) -> None: + total_cpu_memory = get_cpu_memory() + # FIXME(woosuk): Here, it is assumed that the GPUs in a tensor parallel + # group are in the same node. However, the GPUs may span multiple nodes. + num_gpus_per_node = parallel_config.tensor_parallel_size + cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node + + msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space.") + if cpu_memory_usage > 0.7 * total_cpu_memory: + raise ValueError("Too large swap space. " + msg) + elif cpu_memory_usage > 0.4 * total_cpu_memory: + logger.warning("Possibly too large swap space. " + msg) + + +class ParallelConfig: + """Configuration for the distributed execution. + + Args: + pipeline_parallel_size: Number of pipeline parallel groups. + tensor_parallel_size: Number of tensor parallel groups. + worker_use_ray: Whether to use Ray for model workers. Will be set to + True if either pipeline_parallel_size or tensor_parallel_size is + greater than 1. + max_parallel_loading_workers: Maximum number of multiple batches + when load model sequentially. To avoid RAM OOM when using tensor + parallel and large models. + disable_custom_all_reduce: Disable the custom all-reduce kernel and + fall back to NCCL. 
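+
+    The derived world size is pipeline_parallel_size * tensor_parallel_size;
+    when it is greater than 1 on a non-Neuron backend, worker_use_ray is
+    forced to True.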
+ """ + + def __init__( + self, + pipeline_parallel_size: int, + tensor_parallel_size: int, + worker_use_ray: bool, + max_parallel_loading_workers: Optional[int] = None, + disable_custom_all_reduce: bool = False, + ) -> None: + self.pipeline_parallel_size = pipeline_parallel_size + if is_neuron(): + # For Neuron device support, here we assign TP=1 to avoid sharding within vLLM directly. + # Transformer-neuronx would take neuron_tp_degree attribute, and distribute the workload + # to multiple NeuronCores. + self.tensor_parallel_size = 1 + self.neuron_tp_degree = tensor_parallel_size + else: + self.tensor_parallel_size = tensor_parallel_size + self.worker_use_ray = worker_use_ray + self.max_parallel_loading_workers = max_parallel_loading_workers + self.disable_custom_all_reduce = disable_custom_all_reduce + + self.world_size = pipeline_parallel_size * self.tensor_parallel_size + # Ray worker is not supported for Neuron backend. + if self.world_size > 1 and not is_neuron(): + self.worker_use_ray = True + self._verify_args() + + def _verify_args(self) -> None: + if self.pipeline_parallel_size > 1: + raise NotImplementedError( + "Pipeline parallelism is not supported yet.") + if not self.disable_custom_all_reduce and self.world_size > 1: + if is_hip(): + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported on AMD GPUs.") + elif self.pipeline_parallel_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Disabled the custom all-reduce kernel because it is not " + "supported with pipeline parallelism.") + + # FIXME(woosuk): Fix the stability issues and re-enable the custom + # all-reduce kernel. + if not self.disable_custom_all_reduce and self.world_size > 1: + self.disable_custom_all_reduce = True + logger.info( + "Custom all-reduce kernels are temporarily disabled due to " + "stability issues. We will re-enable them once the issues are " + "resolved.") + + +class SchedulerConfig: + """Scheduler configuration. + + Args: + max_num_batched_tokens: Maximum number of tokens to be processed in + a single iteration. + max_num_seqs: Maximum number of sequences to be processed in a single + iteration. + max_model_len: Maximum length of a sequence (including prompt + and generated text). + max_paddings: Maximum number of paddings to be added to a batch. + """ + + def __init__( + self, + max_num_batched_tokens: Optional[int], + max_num_seqs: int, + max_model_len: int, + max_paddings: int, + ) -> None: + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + # If max_model_len is too short, use 2048 as the default value for + # higher throughput. + self.max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_seqs = max_num_seqs + self.max_model_len = max_model_len + self.max_paddings = max_paddings + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) is " + f"smaller than max_model_len ({self.max_model_len}). " + "This effectively limits the maximum sequence length to " + "max_num_batched_tokens and makes vLLM reject longer " + "sequences. 
Please increase max_num_batched_tokens or " + "decrease max_model_len.") + if self.max_num_batched_tokens < self.max_num_seqs: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_num_seqs " + f"({self.max_num_seqs}).") + + +class DeviceConfig: + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if torch.cuda.is_available(): + self.device_type = "cuda" + elif is_neuron(): + self.device_type = "neuron" + else: + raise RuntimeError("No supported device detected.") + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron"]: + self.device = torch.device("cpu") + else: + # Set device with device type + self.device = torch.device(self.device_type) + + @property + def is_neuron(self): + return self.device_type == "neuron" + + +@dataclass +class LoRAConfig: + max_lora_rank: int + max_loras: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + # This is a constant. + lora_vocab_padding_size: ClassVar[int] = 256 + + def __post_init__(self): + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + possible_max_ranks = (8, 16, 32, 64) + possible_lora_extra_vocab_size = (0, 256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_loras ({self.max_loras})") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + if model_config.quantization is not None: + raise ValueError( + "LoRA is not supported with quantized models yet.") + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + +_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"] + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + # Following the common practice, we use float16 for float32 + # models. 
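+                # (e.g. a checkpoint whose config.json declares
+                # torch_dtype=float32 is served in float16 under "auto")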
+ torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + if is_hip() and torch_dtype == torch.float32: + rocm_supported_dtypes = [ + k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + if (k not in _ROCM_NOT_SUPPORTED_DTYPE) + ] + raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning(f"Casting {config_dtype} to {torch_dtype}.") + + return torch_dtype + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Others + "model_max_length", + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + for key in possible_keys: + max_len_key = getattr(hf_config, key, None) + if max_len_key is not None: + derived_max_model_len = min(derived_max_model_len, max_len_key) + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + f"{possible_keys}. Assuming the model's maximum length is " + f"{default_max_len}.") + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + if max_model_len is None: + max_model_len = derived_max_model_len + elif max_model_len > derived_max_model_len: + raise ValueError( + f"User-specified max_model_len ({max_model_len}) is greater than " + f"the derived max_model_len ({max_len_key}={derived_max_model_len}" + " in model's config.json). This may lead to incorrect model " + "outputs or CUDA errors. 
Make sure the value is correct and " + "within the model context size.") + return int(max_model_len) diff --git a/vllm/core/__init__.py b/vllm/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py new file mode 100644 index 0000000..3946096 --- /dev/null +++ b/vllm/core/block_manager.py @@ -0,0 +1,330 @@ +"""A block manager that manages token blocks.""" +import enum +from typing import Dict, List, Optional, Set, Tuple + +from vllm.block import BlockTable, PhysicalTokenBlock +from vllm.sequence import Sequence, SequenceGroup, SequenceStatus +from vllm.utils import Device + + +class BlockAllocator: + """Manages free physical token blocks for a device. + + The allocator maintains a list of free blocks and allocates a block when + requested. When a block is freed, its reference count is decremented. If + the reference count becomes zero, the block is added back to the free list. + """ + + def __init__( + self, + device: Device, + block_size: int, + num_blocks: int, + ) -> None: + self.device = device + self.block_size = block_size + self.num_blocks = num_blocks + + # Initialize the free blocks. + self.free_blocks: BlockTable = [] + for i in range(num_blocks): + block = PhysicalTokenBlock(device=device, + block_number=i, + block_size=block_size) + self.free_blocks.append(block) + + def allocate(self) -> PhysicalTokenBlock: + if not self.free_blocks: + raise ValueError("Out of memory! No free blocks are available.") + block = self.free_blocks.pop() + block.ref_count = 1 + return block + + def free(self, block: PhysicalTokenBlock) -> None: + if block.ref_count == 0: + raise ValueError(f"Double free! {block} is already freed.") + block.ref_count -= 1 + if block.ref_count == 0: + self.free_blocks.append(block) + + def get_num_free_blocks(self) -> int: + return len(self.free_blocks) + + +class AllocStatus(enum.Enum): + """Result for BlockSpaceManager.can_allocate + + 1. Ok: seq_group can be allocated now. + 2. Later: seq_group cannot be allocated. + The capacity of allocator is larger than seq_group required. + 3. Never: seq_group can never be allocated. + The seq_group is too large to allocated in GPU. + """ + OK = enum.auto() + LATER = enum.auto() + NEVER = enum.auto() + + +class BlockSpaceManager: + """Manages the mapping between logical and physical token blocks.""" + + def __init__( + self, + block_size: int, + num_gpu_blocks: int, + num_cpu_blocks: int, + watermark: float = 0.01, + sliding_window: Optional[int] = None, + ) -> None: + self.block_size = block_size + self.num_total_gpu_blocks = num_gpu_blocks + self.num_total_cpu_blocks = num_cpu_blocks + + self.block_sliding_window = None + if sliding_window is not None: + assert sliding_window % block_size == 0, (sliding_window, + block_size) + self.block_sliding_window = sliding_window // block_size + + self.watermark = watermark + assert watermark >= 0.0 + + self.watermark_blocks = int(watermark * num_gpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, block_size, + num_gpu_blocks) + self.cpu_allocator = BlockAllocator(Device.CPU, block_size, + num_cpu_blocks) + # Mapping: seq_id -> BlockTable. + self.block_tables: Dict[int, BlockTable] = {} + + def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: + # FIXME(woosuk): Here we assume that all sequences in the group share + # the same prompt. This may not be true for preempted sequences. 
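+        # Example: with block_size=16, a 33-token prompt occupies 3 logical
+        # blocks, so the group is schedulable only while at least
+        # 3 + watermark_blocks GPU blocks are free; if the total number of
+        # GPU blocks minus the requirement is below the watermark, the group
+        # can never be allocated.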
+ seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + num_required_blocks = len(seq.logical_token_blocks) + + if seq_group.prefix is not None and seq_group.prefix.allocated: + num_required_blocks -= seq_group.prefix.get_num_blocks() + + if self.block_sliding_window is not None: + num_required_blocks = min(num_required_blocks, + self.block_sliding_window) + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + + # Use watermark to avoid frequent cache eviction. + if (self.num_total_gpu_blocks - num_required_blocks < + self.watermark_blocks): + return AllocStatus.NEVER + if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: + return AllocStatus.OK + else: + return AllocStatus.LATER + + def allocate(self, seq_group: SequenceGroup) -> None: + # NOTE: Here we assume that all sequences in the group have the same + # prompt. + seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] + + # Allocate new physical token blocks that will store the prompt tokens. + num_prompt_blocks = len(seq.logical_token_blocks) + + block_table: BlockTable = [] + prefix_block_table: BlockTable = [] + num_prefix_blocks = 0 + + prefix = seq_group.prefix + if prefix is not None and prefix.allocated: + # Prefix has already been allocated. Use the existing block table. + num_prompt_blocks -= prefix.get_num_blocks() + for block in prefix.block_table: + block.ref_count += seq_group.num_seqs() + block_table.append(block) + + for logical_idx in range(num_prompt_blocks): + if (self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window): + block = block_table[logical_idx % self.block_sliding_window] + else: + block = self.gpu_allocator.allocate() + # Set the reference counts of the token blocks. + block.ref_count = seq_group.num_seqs() + block_table.append(block) + + if prefix is not None and not prefix.allocated: + # Allocate blocks for the prefix, we will compute the prefix's + # KV cache in this run. + num_prefix_blocks = prefix.get_num_blocks() + prefix_block_table = block_table[:num_prefix_blocks] + for block in prefix_block_table: + block.ref_count += 1 + prefix.set_block_table(prefix_block_table) + + # Assign the block table for each sequence. + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + self.block_tables[seq.seq_id] = block_table.copy() + + def can_append_slot(self, seq_group: SequenceGroup) -> bool: + # Simple heuristic: If there is at least one free block + # for each sequence, we can append. + num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() + num_seqs = seq_group.num_seqs(status=SequenceStatus.RUNNING) + return num_seqs <= num_free_gpu_blocks + + def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: + """Allocate a physical slot for a new token.""" + logical_blocks = seq.logical_token_blocks + block_table = self.block_tables[seq.seq_id] + + if len(block_table) < len(logical_blocks): + if (self.block_sliding_window + and len(block_table) >= self.block_sliding_window): + # reuse a block + block_table.append(block_table[len(block_table) % + self.block_sliding_window]) + else: + # The sequence has a new logical block. + # Allocate a new physical block. + block = self.gpu_allocator.allocate() + block_table.append(block) + return None + + # We want to append the token to the last physical block. + last_block = block_table[-1] + assert last_block.device == Device.GPU + if last_block.ref_count == 1: + # Not shared with other sequences. Appendable. 
+ return None + else: + # The last block is shared with other sequences. + # Copy on Write: Allocate a new block and copy the tokens. + new_block = self.gpu_allocator.allocate() + block_table[-1] = new_block + self.gpu_allocator.free(last_block) + return last_block.block_number, new_block.block_number + + def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: + # NOTE: fork does not allocate a new physical block. + # Thus, it is always safe from OOM. + src_block_table = self.block_tables[parent_seq.seq_id] + self.block_tables[child_seq.seq_id] = src_block_table.copy() + for block in src_block_table: + block.ref_count += 1 + + def _get_physical_blocks( + self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + # NOTE: Here, we assume that the physical blocks are only shared by + # the sequences in the same group. + blocks: Set[PhysicalTokenBlock] = set() + for seq in seq_group.get_seqs(): + if seq.is_finished(): + continue + blocks.update(self.block_tables[seq.seq_id]) + return list(blocks) + + def can_swap_in(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + num_swapped_seqs = seq_group.num_seqs(status=SequenceStatus.SWAPPED) + num_free_blocks = self.gpu_allocator.get_num_free_blocks() + # NOTE: Conservatively, we assume that every sequence will allocate + # at least one free block right after the swap-in. + # NOTE: This should match the logic in can_append_slot(). + num_required_blocks = len(blocks) + num_swapped_seqs + return num_free_blocks - num_required_blocks >= self.watermark_blocks + + def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]: + # CPU block -> GPU block. + if seq_group.prefix is not None: + # make sure to swap in the prefix first + assert seq_group.prefix.allocated and seq_group.prefix.computed + + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + if seq_group.prefix is not None: + for block in seq_group.prefix.block_table: + new_block_table.append(block) + block.ref_count += 1 + + for cpu_block in block_table: + if cpu_block in mapping: + gpu_block = mapping[cpu_block] + gpu_block.ref_count += 1 + else: + gpu_block = self.gpu_allocator.allocate() + mapping[cpu_block] = gpu_block + new_block_table.append(gpu_block) + # Free the CPU block swapped in to GPU. + self.cpu_allocator.free(cpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + cpu_block.block_number: gpu_block.block_number + for cpu_block, gpu_block in mapping.items() + } + return block_number_mapping + + def can_swap_out(self, seq_group: SequenceGroup) -> bool: + blocks = self._get_physical_blocks(seq_group) + return len(blocks) <= self.cpu_allocator.get_num_free_blocks() + + def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]: + # GPU block -> CPU block. + mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + new_block_table: BlockTable = [] + block_table = self.block_tables[seq.seq_id] + + for gpu_block in block_table: + if (seq_group.prefix is not None + and gpu_block in seq_group.prefix.block_table): + # NOTE: We do not swap out the prefix blocks for now. 
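+                    # Freeing here only drops this sequence's reference to
+                    # the block; the prefix keeps its own reference, so the
+                    # block stays resident in GPU memory.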
+ self.gpu_allocator.free(gpu_block) + continue + + if gpu_block in mapping: + cpu_block = mapping[gpu_block] + cpu_block.ref_count += 1 + else: + cpu_block = self.cpu_allocator.allocate() + mapping[gpu_block] = cpu_block + new_block_table.append(cpu_block) + # Free the GPU block swapped out to CPU. + self.gpu_allocator.free(gpu_block) + self.block_tables[seq.seq_id] = new_block_table + + block_number_mapping = { + gpu_block.block_number: cpu_block.block_number + for gpu_block, cpu_block in mapping.items() + } + return block_number_mapping + + def _free_block_table(self, block_table: BlockTable) -> None: + for block in set(block_table): + if block.device == Device.GPU: + self.gpu_allocator.free(block) + else: + self.cpu_allocator.free(block) + + def free(self, seq: Sequence) -> None: + if seq.seq_id not in self.block_tables: + # Already freed or haven't been scheduled yet. + return + block_table = self.block_tables[seq.seq_id] + self._free_block_table(block_table) + del self.block_tables[seq.seq_id] + + def reset(self) -> None: + for block_table in self.block_tables.values(): + self._free_block_table(block_table) + self.block_tables.clear() + + def get_block_table(self, seq: Sequence) -> List[int]: + block_table = self.block_tables[seq.seq_id] + return [block.block_number for block in block_table] + + def get_num_free_gpu_blocks(self) -> int: + return self.gpu_allocator.get_num_free_blocks() + + def get_num_free_cpu_blocks(self) -> int: + return self.cpu_allocator.get_num_free_blocks() diff --git a/vllm/core/policy.py b/vllm/core/policy.py new file mode 100644 index 0000000..2e9ebbd --- /dev/null +++ b/vllm/core/policy.py @@ -0,0 +1,47 @@ +from collections import deque +from typing import Deque + +from vllm.sequence import SequenceGroup + + +class Policy: + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + raise NotImplementedError + + def sort_by_priority( + self, + now: float, + seq_groups: Deque[SequenceGroup], + ) -> Deque[SequenceGroup]: + return deque( + sorted( + seq_groups, + key=lambda seq_group: self.get_priority(now, seq_group), + reverse=True, + )) + + +class FCFS(Policy): + + def get_priority( + self, + now: float, + seq_group: SequenceGroup, + ) -> float: + return now - seq_group.metrics.arrival_time + + +class PolicyFactory: + + _POLICY_REGISTRY = { + 'fcfs': FCFS, + } + + @classmethod + def get_policy(cls, policy_name: str, **kwargs) -> Policy: + return cls._POLICY_REGISTRY[policy_name](**kwargs) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py new file mode 100644 index 0000000..5e7cc30 --- /dev/null +++ b/vllm/core/scheduler.py @@ -0,0 +1,498 @@ +from collections import deque +import enum +import time +from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set + +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig +from vllm.core.block_manager import AllocStatus, BlockSpaceManager +from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest +from vllm.logger import init_logger +from vllm.sequence import (Sequence, SequenceData, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) +from vllm.prefix import PrefixPool + +logger = init_logger(__name__) + + +class PreemptionMode(enum.Enum): + """Preemption modes. + + 1. Swapping: Swap out the blocks of the preempted sequences to CPU memory + and swap them back in when the sequences are resumed. + 2. 
Recomputation: Discard the blocks of the preempted sequences and + recompute them when the sequences are resumed, treating the sequences as + new prompts. + """ + SWAP = enum.auto() + RECOMPUTE = enum.auto() + + +class SchedulerOutputs: + + def __init__( + self, + scheduled_seq_groups: Iterable[SequenceGroup], + prompt_run: bool, + num_batched_tokens: int, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ignored_seq_groups: List[SequenceGroup], + ) -> None: + self.scheduled_seq_groups = scheduled_seq_groups + self.prompt_run = prompt_run + self.num_batched_tokens = num_batched_tokens + self.blocks_to_swap_in = blocks_to_swap_in + self.blocks_to_swap_out = blocks_to_swap_out + self.blocks_to_copy = blocks_to_copy + # Swap in and swap out should never happen at the same time. + assert not (blocks_to_swap_in and blocks_to_swap_out) + self.ignored_seq_groups = ignored_seq_groups + + self.num_loras = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + + def is_empty(self) -> bool: + # NOTE: We do not consider the ignored sequence groups. + return (not self.scheduled_seq_groups and not self.blocks_to_swap_in + and not self.blocks_to_swap_out and not self.blocks_to_copy) + + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.lora_request.lora_int_id + if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + + +class Scheduler: + + def __init__( + self, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.scheduler_config = scheduler_config + self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config + + self.prompt_limit = min(self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens) + + # Instantiate the scheduling policy. + self.policy = PolicyFactory.get_policy(policy_name="fcfs") + # Create the block space manager. + self.block_manager = BlockSpaceManager( + block_size=self.cache_config.block_size, + num_gpu_blocks=self.cache_config.num_gpu_blocks, + num_cpu_blocks=self.cache_config.num_cpu_blocks, + sliding_window=self.cache_config.sliding_window) + + # Create the prefix pool to cache the prefixes. + self.prefix_pool = PrefixPool(self.cache_config.block_size) + + # Sequence groups in the WAITING state. + self.waiting: Deque[SequenceGroup] = deque() + # Sequence groups in the RUNNING state. + self.running: Deque[SequenceGroup] = deque() + # Sequence groups in the SWAPPED state. + self.swapped: Deque[SequenceGroup] = deque() + + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + + def add_seq_group(self, seq_group: SequenceGroup) -> None: + # Add sequence groups to the waiting queue. + self.waiting.append(seq_group) + + def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a sequence group with the given ID. + + Check if the sequence group with the given ID + is present in any of the state queue. + If present, remove the sequence group from the state queue. + Also, if any of the sequences in the sequence group is not finished, + free the sequence with status `FINISHED_ABORTED`. 
+ Otherwise, do nothing. + + Args: + request_id: The ID(s) of the sequence group to abort. + """ + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + for state_queue in [self.waiting, self.running, self.swapped]: + aborted_groups: List[SequenceGroup] = [] + for seq_group in state_queue: + if not request_ids: + # Using 'break' here may add two extra iterations, + # but is acceptable to reduce complexity . + break + if seq_group.request_id in request_ids: + # Appending aborted group into pending list. + aborted_groups.append(seq_group) + request_ids.remove(seq_group.request_id) + for aborted_group in aborted_groups: + # Remove the sequence group from the state queue. + state_queue.remove(aborted_group) + for seq in aborted_group.get_seqs(): + if seq.is_finished(): + continue + seq.status = SequenceStatus.FINISHED_ABORTED + self.free_seq(seq) + + def has_unfinished_seqs(self) -> bool: + return self.waiting or self.running or self.swapped + + def get_num_unfinished_seq_groups(self) -> int: + return len(self.waiting) + len(self.running) + len(self.swapped) + + def _schedule(self) -> SchedulerOutputs: + # Blocks that need to be swapped or copied before model execution. + blocks_to_swap_in: Dict[int, int] = {} + blocks_to_swap_out: Dict[int, int] = {} + blocks_to_copy: Dict[int, List[int]] = {} + + # Fix the current time. + now = time.monotonic() + + # Join waiting sequences if possible. + if not self.swapped: + ignored_seq_groups: List[SequenceGroup] = [] + scheduled: List[SequenceGroup] = [] + # The total number of sequences on the fly, including the + # requests in the generation phase. + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + seq_lens: List[int] = [] + + # Optimization: We do not sort the waiting queue since the preempted + # sequence groups are added to the front and the new sequence groups + # are added to the back. + leftover_waiting_sequences = deque() + while self.waiting: + seq_group = self.waiting[0] + waiting_seqs = seq_group.get_seqs( + status=SequenceStatus.WAITING) + assert len(waiting_seqs) == 1, ( + "Waiting sequence group should have only one prompt " + "sequence.") + num_prompt_tokens = waiting_seqs[0].get_len() + if num_prompt_tokens > self.prompt_limit: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds limit of {self.prompt_limit}") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.popleft() + continue + + # If the sequence group cannot be allocated, stop. + can_allocate = self.block_manager.can_allocate(seq_group) + if can_allocate == AllocStatus.LATER: + break + elif can_allocate == AllocStatus.NEVER: + logger.warning( + f"Input prompt ({num_prompt_tokens} tokens) is too long" + f" and exceeds the capacity of block_manager") + for seq in waiting_seqs: + seq.status = SequenceStatus.FINISHED_IGNORED + ignored_seq_groups.append(seq_group) + self.waiting.popleft() + continue + + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. 
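+                        # Deferred groups are pushed back onto the front of
+                        # self.waiting once the scan of the waiting queue
+                        # completes.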
+ leftover_waiting_sequences.appendleft(seq_group) + self.waiting.popleft() + continue + + # If the number of batched tokens exceeds the limit, stop. + new_seq_lens = seq_lens + [num_prompt_tokens] + num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + if (num_batched_tokens > + self.scheduler_config.max_num_batched_tokens): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. + num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + num_paddings = num_batched_tokens - sum(new_seq_lens) + if num_paddings > self.scheduler_config.max_paddings: + break + seq_lens = new_seq_lens + + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self.waiting.popleft() + self._allocate(seq_group) + self.running.append(seq_group) + num_curr_seqs += num_new_seqs + scheduled.append(seq_group) + + self.waiting.extendleft(leftover_waiting_sequences) + + if scheduled or ignored_seq_groups: + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=scheduled, + prompt_run=True, + num_batched_tokens=len(seq_lens) * + max(seq_lens) if seq_lens else 0, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=ignored_seq_groups, + ) + return scheduler_outputs + + # NOTE(woosuk): Preemption happens only when there is no available slot + # to keep all the sequence groups in the RUNNING state. + # In this case, the policy is responsible for deciding which sequence + # groups to preempt. + self.running = self.policy.sort_by_priority(now, self.running) + + # Reserve new token slots for the running sequence groups. + running: Deque[SequenceGroup] = deque() + preempted: List[SequenceGroup] = [] + while self.running: + seq_group = self.running.popleft() + while not self.block_manager.can_append_slot(seq_group): + if self.running: + # Preempt the lowest-priority sequence groups. + victim_seq_group = self.running.pop() + self._preempt(victim_seq_group, blocks_to_swap_out) + preempted.append(victim_seq_group) + else: + # No other sequence groups can be preempted. + # Preempt the current sequence group. + self._preempt(seq_group, blocks_to_swap_out) + preempted.append(seq_group) + break + else: + # Append new slots to the sequence group. + self._append_slot(seq_group, blocks_to_copy) + running.append(seq_group) + self.running = running + + # Swap in the sequence groups in the SWAPPED state if possible. + self.swapped = self.policy.sort_by_priority(now, self.swapped) + if not preempted: + num_curr_seqs = sum(seq_group.get_max_num_running_seqs() + for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + leftover_swapped = deque() + + while self.swapped: + seq_group = self.swapped[0] + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + self.swapped.popleft() + continue + + # If the sequence group cannot be swapped in, stop. + if not self.block_manager.can_swap_in(seq_group): + break + + # The total number of sequences in the RUNNING state should not + # exceed the maximum number of sequences. 
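+ # Illustrative note (editorial sketch): get_max_num_running_seqs() is the
+ # worst-case parallelism of the group, e.g. a beam-search request with
+ # best_of=4 counts as 4. With max_num_seqs == 256 and num_curr_seqs == 254,
+ # such a group stays swapped out until enough running sequences finish.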
+ num_new_seqs = seq_group.get_max_num_running_seqs() + if (num_curr_seqs + num_new_seqs > + self.scheduler_config.max_num_seqs): + break + + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self.swapped.popleft() + self._swap_in(seq_group, blocks_to_swap_in) + self._append_slot(seq_group, blocks_to_copy) + num_curr_seqs += num_new_seqs + self.running.append(seq_group) + + self.swapped.extendleft(leftover_swapped) + + # Each sequence in the generation phase only takes one token slot. + # Therefore, the number of batched tokens is equal to the number of + # sequences in the RUNNING state. + num_batched_tokens = sum( + seq_group.num_seqs(status=SequenceStatus.RUNNING) + for seq_group in self.running) + + scheduler_outputs = SchedulerOutputs( + scheduled_seq_groups=self.running, + prompt_run=False, + num_batched_tokens=num_batched_tokens, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ignored_seq_groups=[], + ) + return scheduler_outputs + + def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + # Schedule sequence groups. + # This function call changes the internal states of the scheduler + # such as self.running, self.swapped, and self.waiting. + scheduler_outputs = self._schedule() + now = time.time() + + # Create input data structures. + seq_group_metadata_list: List[SequenceGroupMetadata] = [] + for seq_group in scheduler_outputs.scheduled_seq_groups: + seq_group.maybe_set_first_scheduled_time(now) + + seq_data: Dict[int, SequenceData] = {} + block_tables: Dict[int, List[int]] = {} + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq_id = seq.seq_id + seq_data[seq_id] = seq.data + block_tables[seq_id] = self.block_manager.get_block_table(seq) + + seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group.request_id, + is_prompt=scheduler_outputs.prompt_run, + seq_data=seq_data, + sampling_params=seq_group.sampling_params, + block_tables=block_tables, + lora_request=seq_group.lora_request, + prefix=seq_group.prefix, + state=seq_group.state, + ) + seq_group_metadata_list.append(seq_group_metadata) + return seq_group_metadata_list, scheduler_outputs + + def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: + self.block_manager.fork(parent_seq, child_seq) + + def free_seq(self, seq: Sequence) -> None: + self.block_manager.free(seq) + + def free_finished_seq_groups(self) -> None: + self.running = deque(seq_group for seq_group in self.running + if not seq_group.is_finished()) + + def _allocate(self, seq_group: SequenceGroup) -> None: + self.block_manager.allocate(seq_group) + for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): + seq.status = SequenceStatus.RUNNING + + def _append_slot( + self, + seq_group: SequenceGroup, + blocks_to_copy: Dict[int, List[int]], + ) -> None: + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + ret = self.block_manager.append_slot(seq) + if ret is not None: + src_block, dst_block = ret + if src_block in blocks_to_copy: + blocks_to_copy[src_block].append(dst_block) + else: + blocks_to_copy[src_block] = [dst_block] + + def _preempt( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + preemption_mode: Optional[PreemptionMode] = None, + ) -> None: + # If preemption mode is not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. 
However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + # FIXME(woosuk): This makes our scheduling policy a bit bizarre. + # As swapped sequences are prioritized over waiting sequences, + # sequence groups with multiple sequences are implicitly prioritized + # over sequence groups with a single sequence. + # TODO(woosuk): Support recomputation for sequence groups with multiple + # sequences. This may require a more sophisticated CUDA kernel. + if preemption_mode is None: + if seq_group.get_max_num_running_seqs() == 1: + preemption_mode = PreemptionMode.RECOMPUTE + else: + preemption_mode = PreemptionMode.SWAP + if preemption_mode == PreemptionMode.RECOMPUTE: + self._preempt_by_recompute(seq_group) + elif preemption_mode == PreemptionMode.SWAP: + self._preempt_by_swap(seq_group, blocks_to_swap_out) + else: + raise AssertionError("Invalid preemption mode.") + + def _preempt_by_recompute( + self, + seq_group: SequenceGroup, + ) -> None: + seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + assert len(seqs) == 1 + for seq in seqs: + seq.status = SequenceStatus.WAITING + self.block_manager.free(seq) + # NOTE: For FCFS, we insert the preempted sequence group to the front + # of the waiting queue. + self.waiting.appendleft(seq_group) + + def _preempt_by_swap( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + self._swap_out(seq_group, blocks_to_swap_out) + self.swapped.append(seq_group) + + def _swap_in( + self, + seq_group: SequenceGroup, + blocks_to_swap_in: Dict[int, int], + ) -> None: + mapping = self.block_manager.swap_in(seq_group) + blocks_to_swap_in.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.SWAPPED): + seq.status = SequenceStatus.RUNNING + + def _swap_out( + self, + seq_group: SequenceGroup, + blocks_to_swap_out: Dict[int, int], + ) -> None: + if not self.block_manager.can_swap_out(seq_group): + # FIXME(woosuk): Abort the sequence group instead of aborting the + # entire engine. + raise RuntimeError( + "Aborted due to the lack of CPU swap space. 
Please increase " + "the swap space to avoid this error.") + mapping = self.block_manager.swap_out(seq_group) + blocks_to_swap_out.update(mapping) + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): + seq.status = SequenceStatus.SWAPPED diff --git a/vllm/engine/__init__.py b/vllm/engine/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py new file mode 100644 index 0000000..a5c824d --- /dev/null +++ b/vllm/engine/arg_utils.py @@ -0,0 +1,341 @@ +import argparse +import dataclasses +from dataclasses import dataclass +from typing import Optional, Tuple + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str + tokenizer: Optional[str] = None + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + seed: int = 0 + max_model_len: Optional[int] = None + worker_use_ray: bool = False + pipeline_parallel_size: int = 1 + tensor_parallel_size: int = 1 + max_parallel_loading_workers: Optional[int] = None + block_size: int = 16 + swap_space: int = 4 # GiB + gpu_memory_utilization: float = 0.90 + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_paddings: int = 256 + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + enforce_eager: bool = False + max_context_len_to_capture: int = 8192 + disable_custom_all_reduce: bool = False + enable_lora: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + lora_extra_vocab_size: int = 256 + lora_dtype = 'auto' + max_cpu_loras: Optional[int] = None + device: str = 'auto' + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + """Shared CLI arguments for vLLM engine.""" + + # NOTE: If you update any of the arguments below, please also + # make sure to update docs/source/models/engine_args.rst + + # Model arguments + parser.add_argument( + '--model', + type=str, + default='facebook/opt-125m', + help='name or path of the huggingface model to use') + parser.add_argument( + '--tokenizer', + type=str, + default=EngineArgs.tokenizer, + help='name or path of the huggingface tokenizer to use') + parser.add_argument( + '--revision', + type=str, + default=None, + help='the specific model version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument( + '--code-revision', + type=str, + default=None, + help='the specific revision to use for the model code on ' + 'Hugging Face Hub. It can be a branch name, a tag name, or a ' + 'commit id. If unspecified, will use the default version.') + parser.add_argument( + '--tokenizer-revision', + type=str, + default=None, + help='the specific tokenizer version to use. It can be a branch ' + 'name, a tag name, or a commit id. If unspecified, will use ' + 'the default version.') + parser.add_argument('--tokenizer-mode', + type=str, + default=EngineArgs.tokenizer_mode, + choices=['auto', 'slow'], + help='tokenizer mode. 
"auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + 'always use the slow tokenizer.') + parser.add_argument('--trust-remote-code', + action='store_true', + help='trust remote code from huggingface') + parser.add_argument('--download-dir', + type=str, + default=EngineArgs.download_dir, + help='directory to download and load the weights, ' + 'default to the default cache dir of ' + 'huggingface') + parser.add_argument( + '--load-format', + type=str, + default=EngineArgs.load_format, + choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], + help='The format of the model weights to load. ' + '"auto" will try to load the weights in the safetensors format ' + 'and fall back to the pytorch bin format if safetensors format ' + 'is not available. ' + '"pt" will load the weights in the pytorch bin format. ' + '"safetensors" will load the weights in the safetensors format. ' + '"npcache" will load the weights in pytorch format and store ' + 'a numpy cache to speed up the loading. ' + '"dummy" will initialize the weights with random values, ' + 'which is mainly for profiling.') + parser.add_argument( + '--dtype', + type=str, + default=EngineArgs.dtype, + choices=[ + 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' + ], + help='data type for model weights and activations. ' + 'The "auto" option will use FP16 precision ' + 'for FP32 and FP16 models, and BF16 precision ' + 'for BF16 models.') + parser.add_argument( + '--kv-cache-dtype', + type=str, + choices=['auto', 'fp8_e5m2'], + default=EngineArgs.kv_cache_dtype, + help='Data type for kv cache storage. If "auto", will use model ' + 'data type. Note FP8 is not supported when cuda version is ' + 'lower than 11.8.') + parser.add_argument('--max-model-len', + type=int, + default=EngineArgs.max_model_len, + help='model context length. If unspecified, ' + 'will be automatically derived from the model.') + # Parallel arguments + parser.add_argument('--worker-use-ray', + action='store_true', + help='use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU') + parser.add_argument('--pipeline-parallel-size', + '-pp', + type=int, + default=EngineArgs.pipeline_parallel_size, + help='number of pipeline stages') + parser.add_argument('--tensor-parallel-size', + '-tp', + type=int, + default=EngineArgs.tensor_parallel_size, + help='number of tensor parallel replicas') + parser.add_argument( + '--max-parallel-loading-workers', + type=int, + default=EngineArgs.max_parallel_loading_workers, + help='load model sequentially in multiple batches, ' + 'to avoid RAM OOM when using tensor ' + 'parallel and large models') + # KV cache arguments + parser.add_argument('--block-size', + type=int, + default=EngineArgs.block_size, + choices=[16], + help='token block size') + parser.add_argument('--seed', + type=int, + default=EngineArgs.seed, + help='random seed') + parser.add_argument('--swap-space', + type=int, + default=EngineArgs.swap_space, + help='CPU swap space size (GiB) per GPU') + parser.add_argument( + '--gpu-memory-utilization', + type=float, + default=EngineArgs.gpu_memory_utilization, + help='the fraction of GPU memory to be used for ' + 'the model executor, which can range from 0 to 1.' 
+ 'If unspecified, will use the default value of 0.9.') + parser.add_argument('--max-num-batched-tokens', + type=int, + default=EngineArgs.max_num_batched_tokens, + help='maximum number of batched tokens per ' + 'iteration') + parser.add_argument('--max-num-seqs', + type=int, + default=EngineArgs.max_num_seqs, + help='maximum number of sequences per iteration') + parser.add_argument('--max-paddings', + type=int, + default=EngineArgs.max_paddings, + help='maximum number of paddings in a batch') + parser.add_argument('--disable-log-stats', + action='store_true', + help='disable logging statistics') + # Quantization settings. + parser.add_argument('--quantization', + '-q', + type=str, + choices=['awq', 'gptq', 'squeezellm', 'smoothquant',None], + default=EngineArgs.quantization, + help='Method used to quantize the weights. If ' + 'None, we first check the `quantization_config` ' + 'attribute in the model config file. If that is ' + 'None, we assume the model weights are not ' + 'quantized and use `dtype` to determine the data ' + 'type of the weights.') + parser.add_argument('--enforce-eager', + action='store_true', + help='Always use eager-mode PyTorch. If False, ' + 'will use eager mode and CUDA graph in hybrid ' + 'for maximal performance and flexibility.') + parser.add_argument('--max-context-len-to-capture', + type=int, + default=EngineArgs.max_context_len_to_capture, + help='maximum context length covered by CUDA ' + 'graphs. When a sequence has context length ' + 'larger than this, we fall back to eager mode.') + parser.add_argument('--disable-custom-all-reduce', + action='store_true', + default=EngineArgs.disable_custom_all_reduce, + help='See ParallelConfig') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) + parser.add_argument("--device", + type=str, + default=EngineArgs.device, + choices=["auto", "cuda", "neuron"], + help='Device type for vLLM execution.') + return parser + + @classmethod + def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + # Get the list of attributes of this dataclass. + attrs = [attr.name for attr in dataclasses.fields(cls)] + # Set the attributes from the parsed arguments. 
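+ # Rough usage sketch for this classmethod (editorial, illustrative values):
+ #   parser = EngineArgs.add_cli_args(argparse.ArgumentParser())
+ #   args = parser.parse_args(["--model", "facebook/opt-125m",
+ #                             "--max-num-seqs", "128"])
+ #   engine_args = EngineArgs.from_cli_args(args)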
+ engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args + + def create_engine_configs( + self, + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + DeviceConfig, Optional[LoRAConfig]]: + device_config = DeviceConfig(self.device) + model_config = ModelConfig( + self.model, self.tokenizer, self.tokenizer_mode, + self.trust_remote_code, self.download_dir, self.load_format, + self.dtype, self.seed, self.revision, self.code_revision, + self.tokenizer_revision, self.max_model_len, self.quantization, + self.enforce_eager, self.max_context_len_to_capture) + cache_config = CacheConfig(self.block_size, + self.gpu_memory_utilization, + self.swap_space, self.kv_cache_dtype, + model_config.get_sliding_window()) + parallel_config = ParallelConfig(self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + self.disable_custom_all_reduce) + scheduler_config = SchedulerConfig(self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings) + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + return (model_config, cache_config, parallel_config, scheduler_config, + device_config, lora_config) + + +@dataclass +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous vLLM engine.""" + engine_use_ray: bool = False + disable_log_requests: bool = False + max_log_len: Optional[int] = None + + @staticmethod + def add_cli_args( + parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--engine-use-ray', + action='store_true', + help='use Ray to start the LLM engine in a ' + 'separate process as the server process.') + parser.add_argument('--disable-log-requests', + action='store_true', + help='disable logging requests') + parser.add_argument('--max-log-len', + type=int, + default=None, + help='max number of prompt characters or prompt ' + 'ID numbers being printed in log. ' + 'Default: unlimited.') + return parser diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py new file mode 100644 index 0000000..530314a --- /dev/null +++ b/vllm/engine/async_llm_engine.py @@ -0,0 +1,689 @@ +import asyncio +import time +from functools import partial +from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, + Union, AsyncIterator) + +from vllm.lora.request import LoRARequest +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.engine.ray_utils import initialize_cluster, ray +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams + +logger = init_logger(__name__) + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _raise_exception_on_finish(task: asyncio.Task, + request_tracker: "RequestTracker") -> None: + msg = ("Task finished unexpectedly. This should never happen! 
" + "Please open an issue on Github.") + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise AsyncEngineDeadError( + msg + " See stack trace above for the actual cause.") from exc + raise AsyncEngineDeadError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + +class AsyncStream: + """A stream of RequestOutputs for a request that can be + iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item: RequestOutput) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopAsyncIteration()) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._queue.get() + if isinstance(result, Exception): + raise result + return result + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, + dict]] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._request_streams + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def propagate_exception(self, + exc: Exception, + request_id: Optional[str] = None) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self._request_streams[request_id].put(exc) + else: + for stream in self._request_streams.values(): + stream.put(exc) + + def process_request_output(self, + request_output: RequestOutput, + *, + verbose: bool = False) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + + self._request_streams[request_id].put(request_output) + if request_output.finished: + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + + def add_request(self, request_id: str, + **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = AsyncStream(request_id) + self._new_requests.put_nowait((stream, { + "request_id": request_id, + **engine_add_request_kwargs + })) + + self.new_requests_event.set() + + return stream + + def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + + if request_id not in self._request_streams or self._request_streams[ + request_id].finished: + # The request has already finished or been aborted. 
+ return + + self._request_streams[request_id].finish() + + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + self._request_streams.pop(request_id, None) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + if stream.request_id in finished_requests: + # The request has already been aborted. + stream.finish() + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + self.new_requests_event.clear() + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + # Execute the model. + output = (await self._run_workers_async( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + )) if not scheduler_outputs.is_empty() else [] + + return self._process_model_outputs(output, scheduler_outputs) + + # TODO align + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = await self._run_workers_async( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }) + + # Only the driver worker returns the sampling results. 
+ output = all_outputs[0] + else: + output = [] + + return self._process_model_outputs(output, scheduler_outputs) + """ + + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos, + ) + + async def _run_workers_async( + self, + method: str, + *args, + get_all_outputs: bool = False, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + coros = [] + for worker in self.workers: + if self.parallel_config.worker_use_ray: + coros.append( + worker.execute_method.remote(method, *args, **kwargs)) + else: + executor = getattr(worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(executor, *args, **kwargs))) + + all_outputs = await asyncio.gather(*coros) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + return output + + # TODO align + """ + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + coros = [] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Run the driver worker asynchronously. + driver_executor = getattr(self.driver_worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(driver_executor, *driver_args, **driver_kwargs))) + + # Run the ray workers asynchronously. + for worker in self.workers: + coros.append(worker.execute_method.remote(method, *args, **kwargs)) + + all_outputs = await asyncio.gather(*coros) + return all_outputs + """ + + +class AsyncLLMEngine: + """An asynchronous wrapper for LLMEngine. + + This class is used to wrap the LLMEngine class to make it asynchronous. It + uses asyncio to create a background loop that keeps processing incoming + requests. The LLMEngine is kicked by the generate method when there + are requests in the waiting queue. The generate method yields the outputs + from the LLMEngine to the caller. + + NOTE: For the comprehensive list of arguments, see `LLMEngine`. + + Args: + worker_use_ray: Whether to use Ray for model workers. Required for + distributed execution. 
Should be the same as + `parallel_config.worker_use_ray`. + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the + async frontend will be executed in a separate process as the + model workers. + log_requests: Whether to log the requests. + max_log_len: Maximum number of prompt characters or prompt ID numbers + being printed in log. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args: Arguments for LLMEngine. + *kwargs: Arguments for LLMEngine. + """ + + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + + def __init__(self, + worker_use_ray: bool, + engine_use_ray: bool, + *args, + log_requests: bool = True, + max_log_len: Optional[int] = None, + start_engine_loop: bool = True, + **kwargs) -> None: + self.worker_use_ray = worker_use_ray + self.engine_use_ray = engine_use_ray + self.log_requests = log_requests + self.max_log_len = max_log_len + self.engine = self._init_engine(*args, **kwargs) + + self.background_loop = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + @property + def is_running(self) -> bool: + return (self.background_loop is not None + and not self.background_loop.done()) + + def get_tokenizer(self): + return self.engine.tokenizer.tokenizer + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.is_running: + raise RuntimeError("Background loop is already running.") + self._request_tracker.init_event() + + self._background_loop_unshielded = asyncio.get_event_loop( + ).create_task(self.run_engine_loop()) + self._background_loop_unshielded.add_done_callback( + partial(_raise_exception_on_finish, + request_tracker=self._request_tracker)) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def _init_engine(self, *args, + **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + # FIXME(woosuk): This is a bit hacky. Be careful when changing the + # order of the arguments. + cache_config = args[1] + parallel_config = args[2] + if parallel_config.tensor_parallel_size == 1: + num_gpus = cache_config.gpu_memory_utilization + else: + num_gpus = 1 + engine_class = ray.remote(num_gpus=num_gpus)( + self._engine_class).remote + return engine_class(*args, **kwargs) + + async def engine_step(self) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, finished_requests = ( + self._request_tracker.get_new_and_finished_requests()) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + # TODO: Maybe add add_request_batch to reduce Ray overhead + if self.engine_use_ray: + await self.engine.add_request.remote(**new_request) + else: + await self.engine.add_request_async(**new_request) + + if finished_requests: + await self._engine_abort(finished_requests) + + if self.engine_use_ray: + request_outputs = await self.engine.step.remote() + else: + request_outputs = await self.engine.step_async() + + # Put the outputs into the corresponding streams. 
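+ # Illustrative data-flow sketch (editorial): every RequestOutput produced
+ # by step()/step_async() is routed to the AsyncStream registered under its
+ # request_id, where an awaiting generate() call picks it up; outputs marked
+ # finished also trigger abort_request() so the tracker cleans itself up.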
+ for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests) + + return len(request_outputs) > 0 + + async def _engine_abort(self, request_ids: Iterable[str]): + if self.engine_use_ray: + await self.engine.abort_request.remote(request_ids) + else: + self.engine.abort_request(request_ids) + + async def run_engine_loop(self): + # Initialize the RequestTracker here so it uses the right event loop. + has_requests_in_progress = False + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + has_requests_in_progress = await self.engine_step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> AsyncStream: + if self.log_requests: + shortened_prompt = prompt + shortened_token_ids = prompt_token_ids + if self.max_log_len is not None: + if shortened_prompt is not None: + shortened_prompt = shortened_prompt[:self.max_log_len] + if shortened_token_ids is not None: + shortened_token_ids = shortened_token_ids[:self. + max_log_len] + logger.info(f"Received request {request_id}: " + f"prompt: {shortened_prompt!r}, " + f"prefix_pos: {prefix_pos}," + f"sampling_params: {sampling_params}, " + f"prompt_token_ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") + + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + if arrival_time is None: + arrival_time = time.time() + + if self.engine_use_ray: + prompt_token_ids = await self.engine.encode_request_async.remote( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + else: + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + stream = self._request_tracker.add_request( + request_id, + prompt=prompt, + sampling_params=sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos) + + return stream + + async def generate( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> AsyncIterator[RequestOutput]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. 
We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. + + Yields: + The output `RequestOutput` objects from the LLMEngine for the + request. + + Details: + - If the engine is not running, start the background loop, + which iteratively invokes + :meth:`~vllm.engine.async_llm_engine.AsyncLLMEngine.engine_step` + to process the waiting requests. + - Add the request to the engine's `RequestTracker`. + On the next background loop, this request will be sent to + the underlying engine. + Also, a corresponding `AsyncStream` will be created. + - Wait for the request outputs from `AsyncStream` and yield them. + + Example: + >>> # Please refer to entrypoints/api_server.py for + >>> # the complete example. + >>> + >>> # initialize the engine and the example input + >>> engine = AsyncLLMEngine.from_engine_args(engine_args) + >>> example_input = { + >>> "prompt": "What is LLM?", + >>> "stream": False, # assume the non-streaming case + >>> "temperature": 0.0, + >>> "request_id": 0, + >>> } + >>> + >>> # start the generation + >>> results_generator = engine.generate( + >>> example_input["prompt"], + >>> SamplingParams(temperature=example_input["temperature"]), + >>> example_input["request_id"]) + >>> + >>> # get the results + >>> final_output = None + >>> async for request_output in results_generator: + >>> if await request.is_disconnected(): + >>> # Abort the request if the client disconnects. + >>> await engine.abort(request_id) + >>> # Return or raise an error + >>> ... + >>> final_output = request_output + >>> + >>> # Process and return the final output + >>> ... + """ + # Preprocess the request. + # This should not be used for logging, as it is monotonic time. + arrival_time = time.monotonic() + + try: + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos, + ) + + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the + # request. + self._abort(request_id) + raise e + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError).") + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. 
+ """ + self._request_tracker.abort_request(request_id, + verbose=self.log_requests) + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_model_config.remote() + else: + return self.engine.get_model_config() + + @classmethod + def from_engine_args(cls, + engine_args: AsyncEngineArgs, + start_engine_loop: bool = True) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + # Initialize the cluster. + placement_group = initialize_cluster(parallel_config, + engine_args.engine_use_ray) + # Create the async LLM engine. + engine = cls(parallel_config.worker_use_ray, + engine_args.engine_use_ray, + *engine_configs, + placement_group, + log_requests=not engine_args.disable_log_requests, + log_stats=not engine_args.disable_log_stats, + max_log_len=engine_args.max_log_len, + start_engine_loop=start_engine_loop) + return engine + + async def do_log_stats(self) -> None: + if self.engine_use_ray: + await self.engine.do_log_stats.remote() + else: + self.engine.do_log_stats() \ No newline at end of file diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py new file mode 100644 index 0000000..8a5d447 --- /dev/null +++ b/vllm/engine/llm_engine.py @@ -0,0 +1,1209 @@ +import copy +from collections import defaultdict +from functools import partial +import os +import time +import pickle +import importlib +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, + Union) + +from vllm.lora.request import LoRARequest +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.core.scheduler import Scheduler, SchedulerOutputs +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.metrics import StatLogger, Stats +from vllm.engine.ray_utils import RayWorkerVllm, initialize_cluster, ray +from vllm.logger import init_logger +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.tokenizer import (detokenize_incrementally, + TokenizerGroup) +from vllm.utils import (Counter, set_cuda_visible_devices, get_ip, + get_open_port, get_distributed_init_method) + +if ray: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) +_LOCAL_LOGGING_INTERVAL_SEC = 5 + +# A map between the device type (in device config) to its worker module. +DEVICE_TO_WORKER_MODULE_MAP = { + "cuda": "vllm.worker.worker", + "neuron": "vllm.worker.neuron_worker", +} + +# If the env var is set, it uses the Ray's compiled DAG API +# which optimizes the control plane overhead. +# Run VLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it. +USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) + + +class LLMEngine: + """An LLM engine that receives requests and generates texts. + + This is the main class for the vLLM engine. It receives requests + from clients and generates texts from the LLM. It includes a tokenizer, a + language model (possibly distributed across multiple GPUs), and GPU memory + space allocated for intermediate states (aka KV cache). 
This class utilizes + iteration-level scheduling and efficient memory management to maximize the + serving throughput. + + The `LLM` class wraps this class for offline batched inference and the + `AsyncLLMEngine` class wraps this class for online serving. + + NOTE: The config arguments are derived from the `EngineArgs` class. For the + comprehensive list of arguments, see `EngineArgs`. + + Args: + model_config: The configuration related to the LLM model. + cache_config: The configuration related to the KV cache memory + management. + parallel_config: The configuration related to distributed execution. + scheduler_config: The configuration related to the request scheduler. + device_config: The configuration related to the device. + placement_group: Ray placement group for distributed execution. + Required for distributed execution. + log_stats: Whether to log statistics. + """ + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + placement_group: Optional["PlacementGroup"], + log_stats: bool, + ) -> None: + logger.info( + "Initializing an LLM engine with config: " + f"model={model_config.model!r}, " + f"tokenizer={model_config.tokenizer!r}, " + f"tokenizer_mode={model_config.tokenizer_mode}, " + f"revision={model_config.revision}, " + f"tokenizer_revision={model_config.tokenizer_revision}, " + f"trust_remote_code={model_config.trust_remote_code}, " + f"dtype={model_config.dtype}, " + f"max_seq_len={model_config.max_model_len}, " + f"download_dir={model_config.download_dir!r}, " + f"load_format={model_config.load_format}, " + f"tensor_parallel_size={parallel_config.tensor_parallel_size}, " + f"disable_custom_all_reduce={parallel_config.disable_custom_all_reduce}, " + f"quantization={model_config.quantization}, " + f"enforce_eager={model_config.enforce_eager}, " + f"kv_cache_dtype={cache_config.cache_dtype}, " + f"device_config={device_config.device}, " + f"seed={model_config.seed})") + # TODO(woosuk): Print more configs in debug mode. + + self.model_config = model_config + self.cache_config = cache_config + self.lora_config = lora_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.log_stats = log_stats + self._verify_args() + + self._init_tokenizer() + self.seq_counter = Counter() + + # Create the parallel GPU workers. + if self.parallel_config.worker_use_ray: + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + self._init_workers_ray(placement_group) + else: + self._init_workers() + + # Profile the memory usage and initialize the cache. + self._init_cache() + + # Create the scheduler. + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) + + # Metric Logging. 
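+ # Illustrative note (editorial): when log_stats is enabled, the StatLogger
+ # created below aggregates per-iteration stats and emits them roughly every
+ # _LOCAL_LOGGING_INTERVAL_SEC (5 seconds); do_log_stats() is the hook the
+ # async engine calls to trigger this.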
+ if self.log_stats: + self.stat_logger = StatLogger( + local_interval=_LOCAL_LOGGING_INTERVAL_SEC, + labels=dict(model_name=model_config.model)) + self.stat_logger.info("cache_config", self.cache_config) + + self.forward_dag = None + if USE_RAY_COMPILED_DAG: + self.forward_dag = self._compiled_ray_dag() + + def get_tokenizer_for_seq(self, sequence: Sequence): + return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + + def _dispatch_worker(self): + worker_module = DEVICE_TO_WORKER_MODULE_MAP[ + self.device_config.device_type] + imported_worker = importlib.import_module(worker_module) + Worker = imported_worker.Worker + return Worker + + def _init_workers(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + + assert self.parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + + self.workers: List[Worker] = [] + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + self.workers.append(worker) + + self._run_workers( + "init_model", + get_all_outputs=True, + ) + self._run_workers( + "load_model", + get_all_outputs=True, + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) + # TODO align + """ + self._run_workers("init_model") + self._run_workers("load_model") + """ + + def _init_tokenizer(self, **tokenizer_init_kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(tokenizer_init_kwargs) + self.tokenizer: TokenizerGroup = TokenizerGroup( + self.model_config.tokenizer, **init_kwargs) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + if self.parallel_config.tensor_parallel_size == 1: + num_gpus = self.cache_config.gpu_memory_utilization + else: + num_gpus = 1 + + self.driver_dummy_worker: RayWorkerVllm = None + self.workers: List[RayWorkerVllm] = [] + + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + worker = ray.remote( + num_cpus=0, + num_gpus=num_gpus, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerVllm).remote(self.model_config.trust_remote_code) + + worker_ip = ray.get(worker.get_node_ip.remote()) + if worker_ip == driver_ip and self.driver_dummy_worker is None: + # If the worker is on the same node as the driver, we use it + # as the resource holder for the driver process. + self.driver_dummy_worker = worker + # TODO align + self.workers.append(worker) + else: + self.workers.append(worker) + + if self.driver_dummy_worker is None: + raise ValueError( + "Ray does not allocate any GPUs on the driver node. 
Consider " + "adjusting the Ray placement group or running the driver on a " + "GPU node.") + + driver_node_id, driver_gpu_ids = ray.get( + self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + worker_node_and_gpu_ids = ray.get( + [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + + node_workers = defaultdict(list) + node_gpus = defaultdict(list) + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + # TODO align + """ + node_workers[driver_node_id].append(0) + node_gpus[driver_node_id].extend(driver_gpu_ids) + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, + start=1): + node_workers[node_id].append(i) + node_gpus[node_id].extend(gpu_ids) + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + """ + + # Set CUDA_VISIBLE_DEVICES for the driver. + set_cuda_visible_devices(node_gpus[driver_node_id]) + for worker, (node_id, _) in zip(self.workers, worker_node_and_gpu_ids): + worker.set_cuda_visible_devices.remote(node_gpus[node_id]) + + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + + # Initialize torch distributed process group for the workers. + model_config = copy.deepcopy(self.model_config) + parallel_config = copy.deepcopy(self.parallel_config) + scheduler_config = copy.deepcopy(self.scheduler_config) + device_config = copy.deepcopy(self.device_config) + + + for rank, (worker, (node_id, + _)) in enumerate(zip(self.workers, + worker_node_and_gpu_ids)): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + device_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + )) + self._run_workers( + "init_model", + get_all_outputs=True, + ) + self._run_workers( + "load_model", + get_all_outputs=True, + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers, + ) + # TODO align + """ + for rank, (worker, (node_id, + _)) in enumerate(zip(self.workers, + worker_node_and_gpu_ids), + start=1): + local_rank = node_workers[node_id].index(rank) + worker.init_worker.remote( + lambda rank=rank, local_rank=local_rank: Worker( + model_config, + parallel_config, + scheduler_config, + device_config, + local_rank, + rank, + distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + )) + driver_rank = 0 + driver_local_rank = node_workers[driver_node_id].index(driver_rank) + self.driver_worker = Worker( + model_config, + parallel_config, + scheduler_config, + device_config, + driver_local_rank, + driver_rank, + distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=True, + ) + # don't use cupy for eager mode + self._run_workers("init_model", + cupy_port=get_open_port() + if not model_config.enforce_eager else None) + self._run_workers( + "load_model", + max_concurrent_workers=self.parallel_config. 
+ max_parallel_loading_workers, + ) + """ + + def _verify_args(self) -> None: + self.model_config.verify_with_parallel_config(self.parallel_config) + self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) + + def _init_cache(self) -> None: + """Profiles the memory usage and initializes the KV cache. + + The engine will first conduct a profiling of the existing memory usage. + Then, it calculate the maximum possible number of GPU and CPU blocks + that can be allocated with the remaining free memory. + More details can be found in the + :meth:`~vllm.worker.worker.Worker.profile_num_available_blocks` method + from class :class:`~vllm.worker.Worker`. + + Afterwards, as there may be multiple workers, + we take the minimum number of blocks across all workers + to ensure this can be applied to all of them. + + Finally, the engine will initialize the KV cache + with the calculated number of blocks. + + .. tip:: + You may limit the usage of GPU memory + by adjusting the `gpu_memory_utilization` parameters. + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers( + "profile_num_available_blocks", + get_all_outputs=True, + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config.gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + ) + # TODO align + """ + num_blocks = self._run_workers( + "profile_num_available_blocks", + block_size=self.cache_config.block_size, + gpu_memory_utilization=self.cache_config.gpu_memory_utilization, + cpu_swap_space=self.cache_config.swap_space_bytes, + cache_dtype=self.cache_config.cache_dtype, + ) + """ + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + # FIXME(woosuk): Change to debug log. + logger.info(f"# GPU blocks: {num_gpu_blocks}, " + f"# CPU blocks: {num_cpu_blocks}") + + if num_gpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine.") + max_seq_len = self.cache_config.block_size * num_gpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`gpu_memory_utilization` or decreasing `max_model_len` when " + "initializing the engine.") + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + # Initialize the cache. + self._run_workers("init_cache_engine", cache_config=self.cache_config) + # Warm up the model. This includes capturing the model into CUDA graph + # if enforce_eager is False. + self._run_workers("warm_up_model") + + @classmethod + def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + # Initialize the cluster. 
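+ # Minimal offline-inference sketch using this constructor (editorial,
+ # illustrative only; the prompt and sampling values are made up):
+ #   engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
+ #   engine.add_request("0", "Hello, my name is",
+ #                      SamplingParams(temperature=0.8))
+ #   while engine.has_unfinished_requests():
+ #       for output in engine.step():
+ #           if output.finished:
+ #               print(output.outputs[0].text)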
+ placement_group = initialize_cluster(parallel_config) + # Create the LLM engine. + engine = cls(*engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats) + return engine + + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> None: + """Add a request to the engine's request pool. + + The request is added to the request pool and will be processed by the + scheduler as `engine.step()` is called. The exact scheduling policy is + determined by the scheduler. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters for text generation. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + arrival_time: The arrival time of the request. If None, we use + the current monotonic time. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. + + Details: + - Set arrival_time to the current time if it is None. + - Set prompt_token_ids to the encoded prompt if it is None. + - Create `best_of` number of :class:`~vllm.Sequence` objects. + - Create a :class:`~vllm.SequenceGroup` object + from the list of :class:`~vllm.Sequence`. + - Add the :class:`~vllm.SequenceGroup` object to the scheduler. + + Example: + >>> # initialize engine + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> # set request arguments + >>> example_prompt = "Who is the president of the United States?" + >>> sampling_params = SamplingParams(temperature=0.0) + >>> request_id = 0 + >>> + >>> # add the request to the engine + >>> engine.add_request( + >>> str(request_id), + >>> example_prompt, + >>> SamplingParams(temperature=0.0)) + >>> # continue the request processing + >>> ... + """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.monotonic() + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + # Create the sequences. 
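+ # Each request enters as a single Sequence; extra candidates for best_of / beam search are + # forked from it later in _process_sequence_group_outputs. If prefix_pos is given, the tokens + # prompt_token_ids[:prefix_pos] are registered in the scheduler's prefix pool below so their + # KV cache can be reused by later requests sharing the same prefix.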
+ block_size = self.cache_config.block_size + seq_id = next(self.seq_counter) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) + + # Check whether the input specifies prefix + prefix = self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos], lora_request.lora_int_id + if lora_request else 0) if prefix_pos is not None else None + + # Defensive copy of SamplingParams, which are used by the sampler, + # this doesn't deep-copy LogitsProcessor objects + sampling_params = sampling_params.clone() + + # Create the sequence group. + seq_group = SequenceGroup(request_id, [seq], sampling_params, + arrival_time, lora_request, prefix) + + # Add the sequence group to the scheduler. + self.scheduler.add_seq_group(seq_group) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + """Aborts a request(s) with the given ID. + + Args: + request_id: The ID(s) of the request to abort. + + Details: + - Refer to the + :meth:`~vllm.core.scheduler.Scheduler.abort_seq_group` + from class :class:`~vllm.core.scheduler.Scheduler`. + + Example: + >>> # initialize engine and add a request with request_id + >>> request_id = str(0) + >>> # abort the request + >>> engine.abort_request(request_id) + """ + self.scheduler.abort_seq_group(request_id) + + def get_model_config(self) -> ModelConfig: + """Gets the model configuration.""" + return self.model_config + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_seq_groups() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_seqs() + + def _check_beam_search_early_stopping( + self, + early_stopping: Union[bool, str], + sampling_params: SamplingParams, + best_running_seq: Sequence, + current_worst_seq: Sequence, + ) -> bool: + assert sampling_params.use_beam_search + length_penalty = sampling_params.length_penalty + if early_stopping is True: + return True + + current_worst_score = (current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq( + current_worst_seq).eos_token_id)) + if early_stopping is False: + highest_attainable_score = (best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) + else: + assert early_stopping == "never" + if length_penalty > 0.0: + # If length_penalty > 0.0, beam search will prefer longer + # sequences. The highest attainable score calculation is + # based on the longest possible sequence length in this case. + max_possible_length = max( + best_running_seq.get_prompt_len() + + sampling_params.max_tokens, + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id, + seq_len=max_possible_length)) + else: + # Otherwise, beam search will prefer shorter sequences. The + # highest attainable score calculation is based on the current + # sequence length. 
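+ # Illustrative note, assuming get_beam_search_score normalizes the score as + # cumulative_logprob / (seq_len ** length_penalty): cumulative log-probabilities are + # non-positive, so with length_penalty <= 0 generating more tokens can only lower the + # score, which is why the current length already bounds the attainable score here.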
+ highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) + return current_worst_score >= highest_attainable_score + + def _process_sequence_group_outputs(self, seq_group: SequenceGroup, + outputs: SequenceGroupOutput) -> None: + + # Process prompt logprobs + prompt_logprobs = outputs.prompt_logprobs + if prompt_logprobs is not None: + seq_group.prompt_logprobs = prompt_logprobs + + # Process samples + samples = outputs.samples + parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + existing_finished_seqs = seq_group.get_finished_seqs() + parent_child_dict = { + parent_seq.seq_id: [] + for parent_seq in parent_seqs + } + for sample in samples: + parent_child_dict[sample.parent_seq_id].append(sample) + # List of (child, parent) + child_seqs: List[Tuple[Sequence, Sequence]] = [] + + # Process the child samples for each parent sequence + for parent in parent_seqs: + child_samples: List[SequenceOutput] = parent_child_dict[ + parent.seq_id] + if len(child_samples) == 0: + # This parent sequence has no children samples. Remove + # the parent sequence from the sequence group since it will + # not be used in the future iterations. + parent.status = SequenceStatus.FINISHED_ABORTED + seq_group.remove(parent.seq_id) + self.scheduler.free_seq(parent) + continue + # Fork the parent sequence if there are multiple child samples. + for child_sample in child_samples[:-1]: + new_child_seq_id = next(self.seq_counter) + child = parent.fork(new_child_seq_id) + child.append_token_id(child_sample.output_token, + child_sample.logprobs) + child_seqs.append((child, parent)) + # Continue the parent sequence for the last child sample. + # We reuse the parent sequence here to reduce redundant memory + # copies, especially when using non-beam search sampling methods. + last_child_sample = child_samples[-1] + parent.append_token_id(last_child_sample.output_token, + last_child_sample.logprobs) + child_seqs.append((parent, parent)) + + for seq, _ in child_seqs: + self._decode_sequence(seq, seq_group.sampling_params) + self._check_stop(seq, seq_group.sampling_params) + + # Non-beam search case + if not seq_group.sampling_params.use_beam_search: + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + return + + # Beam search case + # Select the child sequences to keep in the sequence group. + selected_child_seqs = [] + unselected_child_seqs = [] + beam_width = seq_group.sampling_params.best_of + length_penalty = seq_group.sampling_params.length_penalty + + # Select the newly finished sequences with the highest scores + # to replace existing finished sequences. 
+ # Tuple of (seq, parent, is_new) + existing_finished_seqs = [(seq, None, False) + for seq in existing_finished_seqs] + new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs + if seq.is_finished()] + all_finished_seqs = existing_finished_seqs + new_finished_seqs + # Sort the finished sequences by their scores. + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), + reverse=True) + for seq, parent, is_new in all_finished_seqs[:beam_width]: + if is_new: + # A newly generated child sequence finishes and has a high + # score, so we will add it into the sequence group. + selected_child_seqs.append((seq, parent)) + for seq, parent, is_new in all_finished_seqs[beam_width:]: + if is_new: + # A newly generated child sequence finishes but has a low + # score, so we will not add it into the sequence group. + # Additionally, if this sequence is a continuation of a + # parent sequence, we will need remove the parent sequence + # from the sequence group. + unselected_child_seqs.append((seq, parent)) + else: + # An existing finished sequence has a low score, so we will + # remove it from the sequence group. + seq_group.remove(seq.seq_id) + + # select the top beam_width sequences from the running + # sequences for the next iteration to continue the beam + # search. + running_child_seqs = [(seq, parent) for seq, parent in child_seqs + if not seq.is_finished()] + # Sort the running sequences by their scores. + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), + reverse=True) + + # Check if we can stop the beam search. + if len(running_child_seqs) == 0: + # No running sequences, stop the beam search. + stop_beam_search = True + elif len(all_finished_seqs) < beam_width: + # Not enough finished sequences, continue the beam search. + stop_beam_search = False + else: + # Check the early stopping criteria + best_running_seq = running_child_seqs[0][0] + current_worst_seq = all_finished_seqs[beam_width - 1][0] + stop_beam_search = self._check_beam_search_early_stopping( + seq_group.sampling_params.early_stopping, + seq_group.sampling_params, best_running_seq, current_worst_seq) + + if stop_beam_search: + # Stop the beam search and remove all the running sequences from + # the sequence group. + unselected_child_seqs.extend(running_child_seqs) + else: + # Continue the beam search and select the top beam_width sequences + # to continue the beam search. + selected_child_seqs.extend(running_child_seqs[:beam_width]) + # The remaining running sequences will not be used in the next + # iteration. Again, if these sequences are continuations of + # parent sequences, we will need to remove the parent sequences + # from the sequence group. + unselected_child_seqs.extend(running_child_seqs[beam_width:]) + + # For newly created child sequences, add them to the sequence group + # and fork them in block manager if they are not finished. + for seq, parent in selected_child_seqs: + if seq is not parent: + seq_group.add(seq) + if not seq.is_finished(): + self.scheduler.fork_seq(parent, seq) + + # Free the finished and selected parent sequences' memory in block + # manager. Keep them in the sequence group as candidate output. 
+ for seq, parent in selected_child_seqs: + if seq is parent and seq.is_finished(): + self.scheduler.free_seq(seq) + + # Remove the unselected parent sequences from the sequence group and + # free their memory in block manager. + for seq, parent in unselected_child_seqs: + if seq is parent: + # Remove the parent sequence if it is not selected for the + # next iteration. + seq_group.remove(seq.seq_id) + self.scheduler.free_seq(seq) + + def _process_model_outputs( + self, output: SamplerOutput, + scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + now = time.time() + # Update the scheduled sequence groups with the model outputs. + scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups + for seq_group, outputs in zip(scheduled_seq_groups, output): + self._process_sequence_group_outputs(seq_group, outputs) + + # Free the finished sequence groups. + self.scheduler.free_finished_seq_groups() + + # Create the outputs. + request_outputs: List[RequestOutput] = [] + for seq_group in scheduled_seq_groups: + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutput.from_seq_group(seq_group) + request_outputs.append(request_output) + for seq_group in scheduler_outputs.ignored_seq_groups: + request_output = RequestOutput.from_seq_group(seq_group) + request_outputs.append(request_output) + + # Update prefix state, now all the uncomputed prefixes are computed. + for seq_group in scheduled_seq_groups: + if (seq_group.prefix is not None and seq_group.prefix.allocated + and not seq_group.prefix.computed): + seq_group.prefix.computed = True + + # Log stats. + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs)) + + return request_outputs + + def step(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + + .. figure:: https://i.imgur.com/sv2HssD.png + :alt: Overview of the step function + :align: center + + Overview of the step function. + + Details: + - Step 1: Schedules the sequences to be executed in the next + iteration and the token blocks to be swapped in/out/copied. + + - Depending on the scheduling policy, + sequences may be `preempted/reordered`. + - A Sequence Group (SG) refers to a group of sequences + that are generated from the same prompt. + + - Step 2: Calls the workers to execute the model. + - Step 3: Processes the model output. This mainly includes: + + - Decodes the relevant outputs. + - Updates the scheduled sequence groups with model outputs + based on their `sampling parameters` (`use_beam_search` or not). + - Frees the finished sequence groups. + + - Finally, it creates and returns the newly generated results. + + Example: + >>> # Please see the example/ folder for more detailed examples.
+ >>> + >>> # initialize engine and request arguments + >>> engine = LLMEngine.from_engine_args(engine_args) + >>> example_inputs = [(0, "What is LLM?", + >>> SamplingParams(temperature=0.0))] + >>> + >>> # Start the engine with an event loop + >>> while True: + >>> if example_inputs: + >>> req_id, prompt, sampling_params = example_inputs.pop(0) + >>> engine.add_request(str(req_id), prompt, sampling_params) + >>> + >>> # continue the request processing + >>> request_outputs = engine.step() + >>> for request_output in request_outputs: + >>> if request_output.finished: + >>> # return or show the request output + >>> + >>> if not (engine.has_unfinished_requests() or example_inputs): + >>> break + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + output = self._run_workers( + "execute_model", + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in, + blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out, + blocks_to_copy=scheduler_outputs.blocks_to_copy, + ) if not scheduler_outputs.is_empty() else [] + + return self._process_model_outputs(output, scheduler_outputs) + # TODO align + """ + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = self._run_workers( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }, + use_ray_compiled_dag=USE_RAY_COMPILED_DAG) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] + + return self._process_model_outputs(output, scheduler_outputs) + """ + + def do_log_stats(self) -> None: + """Forced log when no requests active.""" + if self.log_stats: + self.stat_logger.log(self._get_stats(scheduler_outputs=None)) + + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs]) -> Stats: + """Get Stats to be Logged to Prometheus.""" + now = time.monotonic() + + # KV Cache Usage in %. + num_total_gpu = self.cache_config.num_gpu_blocks + num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() + gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) + + num_total_cpu = self.cache_config.num_cpu_blocks + cpu_cache_usage = 0. + if num_total_cpu > 0: + num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( + ) + cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) + + # Scheduler State + num_running = len(self.scheduler.running) + num_swapped = len(self.scheduler.swapped) + num_waiting = len(self.scheduler.waiting) + + # Iteration stats if we have scheduler output. + num_prompt_tokens = 0 + num_generation_tokens = 0 + time_to_first_tokens = [] + time_per_output_tokens = [] + time_e2e_requests = [] + if scheduler_outputs is not None: + prompt_run = scheduler_outputs.prompt_run + + # Number of Tokens. + if prompt_run: + num_prompt_tokens = sum( + len(seq_group.prompt_token_ids) + for seq_group in scheduler_outputs.scheduled_seq_groups) + num_generation_tokens = sum( + seq_group.num_seqs() + for seq_group in scheduler_outputs.scheduled_seq_groups) + else: + num_generation_tokens = scheduler_outputs.num_batched_tokens + + # Latency Timings. + time_last_iters = [] + for seq_group in scheduler_outputs.scheduled_seq_groups: + # Time since last token. (n.b. 
updates seq_group.metrics.last_token_time) + time_last_iters.append(seq_group.get_last_latency(now)) + # Time since arrival for all finished requests. + if seq_group.is_finished(): + time_e2e_requests.append(now - + seq_group.metrics.arrival_time) + + time_to_first_tokens = time_last_iters if prompt_run else [] + time_per_output_tokens = [] if prompt_run else time_last_iters + + return Stats( + now=now, + num_running=num_running, + num_swapped=num_swapped, + num_waiting=num_waiting, + gpu_cache_usage=gpu_cache_usage, + cpu_cache_usage=cpu_cache_usage, + num_prompt_tokens=num_prompt_tokens, + num_generation_tokens=num_generation_tokens, + time_to_first_tokens=time_to_first_tokens, + time_per_output_tokens=time_per_output_tokens, + time_e2e_requests=time_e2e_requests, + ) + + def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: + """Decodes the new token for a sequence.""" + (new_tokens, new_output_text, prefix_offset, + read_offset) = detokenize_incrementally( + self.get_tokenizer_for_seq(seq), + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + if seq.tokens is None: + seq.tokens = new_tokens + else: + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_output_text + + def _check_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + """Stop the finished sequences.""" + for stop_str in sampling_params.stop: + if seq.output_text.endswith(stop_str): + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + return + if seq.get_last_token_id() in sampling_params.stop_token_ids: + stop_str = self.get_tokenizer_for_seq(seq).convert_ids_to_tokens( + seq.get_last_token_id()) + self._finalize_sequence(seq, sampling_params, stop_str) + seq.status = SequenceStatus.FINISHED_STOPPED + return + + # Check if the sequence has reached max_model_len. + if seq.get_len() > self.scheduler_config.max_model_len: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has reached max_tokens. + if seq.get_output_len() == sampling_params.max_tokens: + seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED + return + + # Check if the sequence has generated the EOS token. + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.get_tokenizer_for_seq(seq).eos_token_id): + seq.status = SequenceStatus.FINISHED_STOPPED + return + + def _finalize_sequence(self, seq: Sequence, + sampling_params: SamplingParams, + stop_string: str) -> None: + if sampling_params.include_stop_str_in_output: + return + + if stop_string and seq.output_text.endswith(stop_string): + # Truncate the output text so that the stop string is + # not included in the output. + seq.output_text = seq.output_text[:-len(stop_string)] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." 
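+ # The call is broadcast to every worker via _run_workers; since get_all_outputs is not set, + # the workers' results are asserted identical and the first one is returned.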
+ return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + + def _run_workers_in_batch( + self, + workers, + method: str, + *args, + **kwargs, + ): + all_outputs = [] + for worker in workers: + if self.parallel_config.worker_use_ray: + executor = partial(worker.execute_method.remote, method) + else: + executor = getattr(worker, method) + + output = executor(*args, **kwargs) + all_outputs.append(output) + if self.parallel_config.worker_use_ray: + all_outputs = ray.get(all_outputs) + return all_outputs + + def _run_workers( + self, + method: str, + *args, + get_all_outputs: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + all_outputs = [] + if max_concurrent_workers: + work_groups = [ + self.workers[i:i + max_concurrent_workers] + for i in range(0, len(self.workers), max_concurrent_workers) + ] + else: + work_groups = [self.workers] + + for workers in work_groups: + all_outputs.extend( + self._run_workers_in_batch(workers, method, *args, **kwargs)) + + if get_all_outputs: + return all_outputs + + # Make sure all workers have the same results. + output = all_outputs[0] + for other_output in all_outputs[1:]: + assert output == other_output + return output + + # TODO align + """ + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + use_ray_compiled_dag: bool = False, + **kwargs, + ) -> Any: + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + if use_ray_compiled_dag: + # Right now, compiled DAG can only accept a single + # input. TODO(sang): Fix it. + output_channels = self.forward_dag.execute(1) + else: + # Start the ray workers first. + ray_worker_outputs = [ + worker.execute_method.remote(method, *args, **kwargs) + for worker in self.workers + ] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Start the driver worker after all the ray workers. + driver_worker_output = getattr(self.driver_worker, + method)(*driver_args, **driver_kwargs) + + # Get the results of the ray workers. + if self.workers: + if use_ray_compiled_dag: + try: + ray_worker_outputs = [ + pickle.loads(chan.begin_read()) + for chan in output_channels + ] + finally: + # Has to call end_read in order to reuse the DAG. + for chan in output_channels: + chan.end_read() + else: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return [driver_worker_output] + ray_worker_outputs + + def _compiled_ray_dag(self): + import pkg_resources + required_version = "2.9" + current_version = pkg_resources.get_distribution("ray").version + if current_version < required_version: + raise ValueError(f"Ray version {required_version} or greater is " + f"required, but found {current_version}") + + from ray.dag import MultiOutputNode, InputNode + assert self.parallel_config.worker_use_ray + + # Right now, compiled DAG requires at least 1 arg. We send + # a dummy value for now. It will be fixed soon. 
+ with InputNode() as input_data: + forward_dag = MultiOutputNode([ + worker.execute_model_compiled_dag_remote.bind(input_data) + for worker in self.workers + ]) + return forward_dag.experimental_compile() + """ \ No newline at end of file diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py new file mode 100644 index 0000000..54b09c3 --- /dev/null +++ b/vllm/engine/metrics.py @@ -0,0 +1,225 @@ +from vllm.logger import init_logger +from prometheus_client import Counter, Gauge, Histogram, Info, REGISTRY, disable_created_metrics + +import time +import numpy as np +from typing import Dict, List +from dataclasses import dataclass + +logger = init_logger(__name__) + +disable_created_metrics() + +# The begin-* and end* here are used by the documentation generator +# to extract the metrics definitions. + + +# begin-metrics-definitions +class Metrics: + + def __init__(self, labelnames: List[str]): + # Unregister any existing vLLM collectors + for collector in list(REGISTRY._collector_to_names): + if hasattr(collector, "_name") and "vllm" in collector._name: + REGISTRY.unregister(collector) + + self.info_cache_config = Info( + name='vllm:cache_config', + documentation='information of cache_config') + + # System stats + self.gauge_scheduler_running = Gauge( + name="vllm:num_requests_running", + documentation="Number of requests currently running on GPU.", + labelnames=labelnames) + self.gauge_scheduler_swapped = Gauge( + name="vllm:num_requests_swapped", + documentation="Number of requests swapped to CPU.", + labelnames=labelnames) + self.gauge_scheduler_waiting = Gauge( + name="vllm:num_requests_waiting", + documentation="Number of requests waiting to be processed.", + labelnames=labelnames) + self.gauge_gpu_cache_usage = Gauge( + name="vllm:gpu_cache_usage_perc", + documentation="GPU KV-cache usage. 1 means 100 percent usage.", + labelnames=labelnames) + self.gauge_cpu_cache_usage = Gauge( + name="vllm:cpu_cache_usage_perc", + documentation="CPU KV-cache usage. 
1 means 100 percent usage.", + labelnames=labelnames) + + # Raw stats from last model iteration + self.counter_prompt_tokens = Counter( + name="vllm:prompt_tokens_total", + documentation="Number of prefill tokens processed.", + labelnames=labelnames) + self.counter_generation_tokens = Counter( + name="vllm:generation_tokens_total", + documentation="Number of generation tokens processed.", + labelnames=labelnames) + self.histogram_time_to_first_token = Histogram( + name="vllm:time_to_first_token_seconds", + documentation="Histogram of time to first token in seconds.", + labelnames=labelnames, + buckets=[ + 0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5, + 0.75, 1.0, 2.5, 5.0, 7.5, 10.0 + ]) + self.histogram_time_per_output_token = Histogram( + name="vllm:time_per_output_token_seconds", + documentation="Histogram of time per output token in seconds.", + labelnames=labelnames, + buckets=[ + 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, + 1.0, 2.5 + ]) + self.histogram_e2e_request_latency = Histogram( + name="vllm:e2e_request_latency_seconds", + documentation="Histogram of end to end request latency in seconds.", + labelnames=labelnames, + buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) + + # Legacy metrics + self.gauge_avg_prompt_throughput = Gauge( + name="vllm:avg_prompt_throughput_toks_per_s", + documentation="Average prefill throughput in tokens/s.", + labelnames=labelnames, + ) + self.gauge_avg_generation_throughput = Gauge( + name="vllm:avg_generation_throughput_toks_per_s", + documentation="Average generation throughput in tokens/s.", + labelnames=labelnames, + ) + + +# end-metrics-definitions + + +@dataclass +class Stats: + """Created by LLMEngine for use by StatLogger.""" + now: float + + # System stats. + num_running: int + num_waiting: int + num_swapped: int + gpu_cache_usage: float + cpu_cache_usage: float + + # Raw stats from last model iteration. + num_prompt_tokens: int + num_generation_tokens: int + time_to_first_tokens: List[float] + time_per_output_tokens: List[float] + time_e2e_requests: List[float] + + +class StatLogger: + """StatLogger is used LLMEngine to log to Promethus and Stdout.""" + + def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: + # Metadata for logging locally. + self.last_local_log = time.monotonic() + self.local_interval = local_interval + + # Tracked stats over current local logging interval. + self.num_prompt_tokens: List[int] = [] + self.num_generation_tokens: List[int] = [] + + # Prometheus metrics + self.labels = labels + self.metrics = Metrics(labelnames=list(labels.keys())) + + def info(self, type: str, obj: object) -> None: + if type == "cache_config": + self.metrics.info_cache_config.info(obj.metrics_info()) + + def _get_throughput(self, tracked_stats: List[int], now: float) -> float: + return float(np.sum(tracked_stats) / (now - self.last_local_log)) + + def _local_interval_elapsed(self, now: float) -> bool: + elapsed_time = now - self.last_local_log + return elapsed_time > self.local_interval + + def _log_prometheus(self, stats: Stats) -> None: + # Set system stat gauges. 
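+ # Each metric is resolved against the constant labels handed to StatLogger + # (e.g. a hypothetical {"model_name": "my-model"}), so .labels(**self.labels) selects the + # per-label child before calling set()/inc()/observe().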
+ self.metrics.gauge_scheduler_running.labels(**self.labels).set( + stats.num_running) + self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( + stats.num_swapped) + self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( + stats.num_waiting) + self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( + stats.gpu_cache_usage) + self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( + stats.cpu_cache_usage) + + # Add to token counters. + self.metrics.counter_prompt_tokens.labels(**self.labels).inc( + stats.num_prompt_tokens) + self.metrics.counter_generation_tokens.labels(**self.labels).inc( + stats.num_generation_tokens) + + # Observe request level latencies in histograms. + for ttft in stats.time_to_first_tokens: + self.metrics.histogram_time_to_first_token.labels( + **self.labels).observe(ttft) + for tpot in stats.time_per_output_tokens: + self.metrics.histogram_time_per_output_token.labels( + **self.labels).observe(tpot) + for e2e in stats.time_e2e_requests: + self.metrics.histogram_e2e_request_latency.labels( + **self.labels).observe(e2e) + + def _log_prometheus_interval(self, prompt_throughput: float, + generation_throughput: float) -> None: + # Logs metrics to prometheus that are computed every logging_interval. + # Support legacy gauge metrics that make throughput calculations on the vLLM side. + # Moving forward, we should use counters like counter_prompt_tokens, counter_generation_tokens + # Which log raw data and calculate summaries using rate() on the grafana/prometheus side. + # See https://github.com/vllm-project/vllm/pull/2316#discussion_r1464204666 + self.metrics.gauge_avg_prompt_throughput.labels( + **self.labels).set(prompt_throughput) + self.metrics.gauge_avg_generation_throughput.labels( + **self.labels).set(generation_throughput) + + def log(self, stats: Stats) -> None: + """Called by LLMEngine. + Logs to prometheus and tracked stats every iteration. + Logs to Stdout every self.local_interval seconds.""" + + # Log to prometheus. + self._log_prometheus(stats) + + # Save tracked stats for token counters. + self.num_prompt_tokens.append(stats.num_prompt_tokens) + self.num_generation_tokens.append(stats.num_generation_tokens) + + # Log locally every local_interval seconds. + if self._local_interval_elapsed(stats.now): + + # Compute summary metrics for tracked stats (and log them to promethus if applicable). + prompt_throughput = self._get_throughput(self.num_prompt_tokens, + now=stats.now) + generation_throughput = self._get_throughput( + self.num_generation_tokens, now=stats.now) + self._log_prometheus_interval( + prompt_throughput=prompt_throughput, + generation_throughput=generation_throughput) + + # Log to stdout. + logger.info( + f"Avg prompt throughput: {prompt_throughput:.1f} tokens/s, " + f"Avg generation throughput: {generation_throughput:.1f} tokens/s, " + f"Running: {stats.num_running} reqs, " + f"Swapped: {stats.num_swapped} reqs, " + f"Pending: {stats.num_waiting} reqs, " + f"GPU KV cache usage: {stats.gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: {stats.cpu_cache_usage * 100:.1f}%") + + # Reset tracked stats for next interval. 
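+ # Interval throughput above is sum(tokens in the window) / (now - last_local_log); clearing + # the token lists and advancing last_local_log below starts a fresh logging window.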
+ self.num_prompt_tokens = [] + self.num_generation_tokens = [] + self.last_local_log = stats.now diff --git a/vllm/engine/ray_utils.py b/vllm/engine/ray_utils.py new file mode 100644 index 0000000..fb4e770 --- /dev/null +++ b/vllm/engine/ray_utils.py @@ -0,0 +1,157 @@ +import pickle + +from typing import Optional, List, Tuple, TYPE_CHECKING + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.utils import is_hip, set_cuda_visible_devices, get_ip + +logger = init_logger(__name__) + +try: + import ray + + class RayWorkerVllm: + """Ray wrapper for vllm.worker.Worker, allowing Worker to be + lazily initialized after Ray sets CUDA_VISIBLE_DEVICES.""" + + def __init__(self, init_cached_hf_modules=False) -> None: + if init_cached_hf_modules: + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + self.worker = None + # The compiled DAG runs the main execution in a + # different thread that calls cuda.set_device. + # This flag records whether set_device has been + # called on that thread. + self.compiled_dag_cuda_device_set = False + + def init_worker(self, worker_init_fn): + self.worker = worker_init_fn() + + def __getattr__(self, name): + return getattr(self.worker, name) + + def execute_method(self, method, *args, **kwargs): + executor = getattr(self, method) + return executor(*args, **kwargs) + + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + gpu_ids = ray.get_gpu_ids() + return node_id, gpu_ids + + def set_cuda_visible_devices(self, device_ids) -> None: + set_cuda_visible_devices(device_ids) + + def execute_model_compiled_dag_remote(self, ignored): + """Used only when compiled DAG is enabled.""" + import torch + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + output = self.worker.execute_model() + output = pickle.dumps(output) + return output + +except ImportError as e: + logger.warning(f"Failed to import Ray with {e!r}. " + "For distributed inference, please install Ray with " + "`pip install ray`.") + ray = None + RayWorkerVllm = None + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + + +def initialize_cluster( + parallel_config: ParallelConfig, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, +) -> Optional["PlacementGroup"]: + """Initialize the distributed cluster, using Ray when required. + + Args: + parallel_config: The configurations for parallel execution. + engine_use_ray: Whether to use Ray for async engine. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + + Returns: + An optional `PlacementGroup`. It includes the specification + of the resources for each distributed worker. None if Ray is + not used. + """ + if parallel_config.worker_use_ray or engine_use_ray: + if ray is None: + raise ImportError( + "Ray is not installed. Please install Ray to use distributed " + "serving.") + import os + enable_head_ray = os.environ.get("ENABLE_HEAD_RAY", None) + if enable_head_ray is None: + if is_hip(): + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init() + # TODO align + """ + # Connect to a ray cluster.
+ if is_hip(): + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address, ignore_reinit_error=True) + """ + + if not parallel_config.worker_use_ray: + assert parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") + return None + + # Create placement group for worker processes + current_placement_group = ray.util.get_current_placement_group() + if current_placement_group: + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + gpu_bundles = 0 + for bundle in bundles: + bundle_gpus = bundle.get("GPU", 0) + if bundle_gpus > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 GPU.") + if bundle_gpus: + gpu_bundles += 1 + if parallel_config.world_size > gpu_bundles: + raise ValueError( + "The number of required GPUs exceeds the total number of " + "available GPUs in the placement group.") + else: + num_gpus_in_cluster = ray.cluster_resources().get("GPU", 0) + if parallel_config.world_size > num_gpus_in_cluster: + raise ValueError( + "The number of required GPUs exceeds the total number of " + "available GPUs in the cluster.") + # Create a new placement group + placement_group_specs = ([{"GPU": 1}] * parallel_config.world_size) + current_placement_group = ray.util.placement_group( + placement_group_specs) + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + ray.get(current_placement_group.ready(), timeout=1800) + + return current_placement_group diff --git a/vllm/entrypoints/__init__.py b/vllm/entrypoints/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py new file mode 100644 index 0000000..e7af2c6 --- /dev/null +++ b/vllm/entrypoints/api_server.py @@ -0,0 +1,105 @@ +""" +NOTE: This API server is used only for demonstrating usage of AsyncEngine and simple performance benchmarks. +It is not intended for production use. For production use, we recommend using our OpenAI compatible server. +We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. +""" + +import argparse +import json +from typing import AsyncGenerator + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse +import uvicorn + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.sampling_params import SamplingParams +from vllm.utils import random_uuid + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +app = FastAPI() +engine = None + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.post("/generate") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). 
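+ + Example body (illustrative values only): + {"prompt": "San Francisco is a", "max_tokens": 16, "temperature": 0.0, "stream": false}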
+ """ + request_dict = await request.json() + prompt = request_dict.pop("prompt") + prefix_pos = request_dict.pop("prefix_pos", None) + stream = request_dict.pop("stream", False) + sampling_params = SamplingParams(**request_dict) + request_id = random_uuid() + + results_generator = engine.generate(prompt, + sampling_params, + request_id, + prefix_pos=prefix_pos) + + # Streaming case + async def stream_results() -> AsyncGenerator[bytes, None]: + async for request_output in results_generator: + prompt = request_output.prompt + text_outputs = [ + prompt + output.text for output in request_output.outputs + ] + ret = {"text": text_outputs} + yield (json.dumps(ret) + "\0").encode("utf-8") + + if stream: + return StreamingResponse(stream_results()) + + # Non-streaming case + final_output = None + async for request_output in results_generator: + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + text_outputs = [prompt + output.text for output in final_output.outputs] + ret = {"text": text_outputs} + return JSONResponse(ret) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--ssl-keyfile", type=str, default=None) + parser.add_argument("--ssl-certfile", type=str, default=None) + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + + app.root_path = args.root_path + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="debug", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py new file mode 100644 index 0000000..fc82018 --- /dev/null +++ b/vllm/entrypoints/llm.py @@ -0,0 +1,220 @@ +from typing import List, Optional, Union + +from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.lora.request import LoRARequest +from vllm.engine.arg_utils import EngineArgs +from vllm.engine.llm_engine import LLMEngine +from vllm.outputs import RequestOutput +from vllm.sampling_params import SamplingParams +from vllm.utils import Counter + + +class LLM: + """An LLM for generating texts from given prompts and sampling parameters. + + This class includes a tokenizer, a language model (possibly distributed + across multiple GPUs), and GPU memory space allocated for intermediate + states (aka KV cache). Given a batch of prompts and sampling parameters, + this class generates texts from the model, using an intelligent batching + mechanism and efficient memory management. + + NOTE: This class is intended to be used for offline inference. For online + serving, use the `AsyncLLMEngine` class instead. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. + + Args: + model: The name or path of a HuggingFace Transformers model. + tokenizer: The name or path of a HuggingFace Transformers tokenizer. + tokenizer_mode: The tokenizer mode. 
"auto" will use the fast tokenizer + if available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + tensor_parallel_size: The number of GPUs to use for distributed + execution with tensor parallelism. + dtype: The data type for the model weights and activations. Currently, + we support `float32`, `float16`, and `bfloat16`. If `auto`, we use + the `torch_dtype` attribute specified in the model config file. + However, if the `torch_dtype` in the config is `float32`, we will + use `float16` instead. + quantization: The method used to quantize the model weights. Currently, + we support "awq", "gptq" and "squeezellm". If None, we first check + the `quantization_config` attribute in the model config file. If + that is None, we assume the model weights are not quantized and use + `dtype` to determine the data type of the weights. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. + seed: The seed to initialize the random number generator for sampling. + gpu_memory_utilization: The ratio (between 0 and 1) of GPU memory to + reserve for the model weights, activations, and KV cache. Higher + values will increase the KV cache size and thus improve the model's + throughput. However, if the value is too high, it may cause out-of- + memory (OOM) errors. + swap_space: The size (GiB) of CPU memory per GPU to use as swap space. + This can be used for temporarily storing the states of the requests + when their `best_of` sampling parameters are larger than 1. If all + requests will have `best_of=1`, you can safely set this to 0. + Otherwise, too small values may cause out-of-memory (OOM) errors. + enforce_eager: Whether to enforce eager execution. If True, we will + disable CUDA graph and always execute the model in eager mode. + If False, we will use CUDA graph and eager execution in hybrid. + max_context_len_to_capture: Maximum context len covered by CUDA graphs. + When a sequence has context length larger than this, we fall back + to eager mode. 
+ disable_custom_all_reduce: See ParallelConfig + """ + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + enforce_eager: bool = False, + max_context_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + engine_args = EngineArgs( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args) + self.request_counter = Counter() + + def get_tokenizer( + self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + self.llm_engine.tokenizer.tokenizer = tokenizer + + def generate( + self, + prompts: Optional[Union[str, List[str]]] = None, + sampling_params: Optional[SamplingParams] = None, + prompt_token_ids: Optional[List[List[int]]] = None, + prefix_pos: Optional[Union[int, List[int]]] = None, + use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, + ) -> List[RequestOutput]: + """Generates the completions for the input prompts. + + NOTE: This class automatically batches the given prompts, considering + the memory constraint. For the best performance, put all of your prompts + into a single list and pass it to this method. + + Args: + prompts: A list of prompts to generate completions for. + sampling_params: The sampling parameters for text generation. If + None, we use the default sampling parameters. + prompt_token_ids: A list of token IDs for the prompts. If None, we + use the tokenizer to convert the prompts to token IDs. + prefix_pos: If not None, we use the given position as the prefix + position for each prompt. We will cache the prefix's KV + cache and reuse it for the next request with the same prefix. + This is an experimental feature, and may be replaced with + automatic prefix caching in the future. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + + Returns: + A list of `RequestOutput` objects containing the generated + completions in the same order as the input prompts. + """ + if prompts is None and prompt_token_ids is None: + raise ValueError("Either prompts or prompt_token_ids must be " + "provided.") + if isinstance(prompts, str): + # Convert a single prompt to a list. + prompts = [prompts] + if (prompts is not None and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids)): + raise ValueError("The lengths of prompts and prompt_token_ids " + "must be the same.") + if sampling_params is None: + # Use default sampling params. 
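+ # A single shared SamplingParams instance is used for every prompt added below; + # this path does not support per-prompt sampling parameters.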
+ sampling_params = SamplingParams() + + # Add requests to the engine. + num_requests = len(prompts) if prompts is not None else len( + prompt_token_ids) + for i in range(num_requests): + prompt = prompts[i] if prompts is not None else None + prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None + token_ids = None if prompt_token_ids is None else prompt_token_ids[ + i] + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request, + prefix_pos=prefix_pos_i) + return self._run_engine(use_tqdm) + + def _add_request( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request, + prefix_pos=prefix_pos) + + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm(total=num_requests, desc="Processed prompts") + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. + outputs = sorted(outputs, key=lambda x: int(x.request_id)) + return outputs diff --git a/vllm/entrypoints/openai/__init__.py b/vllm/entrypoints/openai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py new file mode 100644 index 0000000..b2f0401 --- /dev/null +++ b/vllm/entrypoints/openai/api_server.py @@ -0,0 +1,251 @@ +import argparse +import asyncio +import json +from contextlib import asynccontextmanager +import os +import importlib +import inspect + +from prometheus_client import make_asgi_app +import fastapi +import uvicorn +from http import HTTPStatus +from fastapi import Request +from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, StreamingResponse, Response + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse +from vllm.logger import init_logger +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_engine import LoRA + +TIMEOUT_KEEP_ALIVE = 5 # seconds + +openai_serving_chat: OpenAIServingChat = None +openai_serving_completion: OpenAIServingCompletion = None +logger = init_logger(__name__) + + +@asynccontextmanager +async def lifespan(app: fastapi.FastAPI): + + async def _force_log(): + while True: + await asyncio.sleep(10) + await engine.do_log_stats() + + if not engine_args.disable_log_stats: + asyncio.create_task(_force_log()) + + yield + + +app = fastapi.FastAPI(lifespan=lifespan) + + +class LoRAParserAction(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + lora_list = [] + for item in 
values: + name, path = item.split('=') + lora_list.append(LoRA(name, path)) + setattr(namespace, self.dest, lora_list) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="vLLM OpenAI-Compatible RESTful API server.") + parser.add_argument("--host", type=str, default=None, help="host name") + parser.add_argument("--port", type=int, default=8000, help="port number") + parser.add_argument("--allow-credentials", + action="store_true", + help="allow credentials") + parser.add_argument("--allowed-origins", + type=json.loads, + default=["*"], + help="allowed origins") + parser.add_argument("--allowed-methods", + type=json.loads, + default=["*"], + help="allowed methods") + parser.add_argument("--allowed-headers", + type=json.loads, + default=["*"], + help="allowed headers") + parser.add_argument( + "--api-key", + type=str, + default=None, + help= + "If provided, the server will require this key to be presented in the header." + ) + parser.add_argument("--served-model-name", + type=str, + default=None, + help="The model name used in the API. If not " + "specified, the model name will be the same as " + "the huggingface name.") + parser.add_argument( + "--lora-modules", + type=str, + default=None, + nargs='+', + action=LoRAParserAction, + help= + "LoRA module configurations in the format name=path. Multiple modules can be specified." + ) + parser.add_argument("--chat-template", + type=str, + default=None, + help="The file path to the chat template, " + "or the template in single-line form " + "for the specified model") + parser.add_argument("--response-role", + type=str, + default="assistant", + help="The role name to return if " + "`request.add_generation_prompt=true`.") + parser.add_argument("--ssl-keyfile", + type=str, + default=None, + help="The file path to the SSL key file") + parser.add_argument("--ssl-certfile", + type=str, + default=None, + help="The file path to the SSL cert file") + parser.add_argument( + "--root-path", + type=str, + default=None, + help="FastAPI root_path when app is behind a path based routing proxy") + parser.add_argument( + "--middleware", + type=str, + action="append", + default=[], + help="Additional ASGI middleware to apply to the app. " + "We accept multiple --middleware arguments. " + "The value should be an import path. " + "If a function is provided, vLLM will add it to the server using @app.middleware('http'). " + "If a class is provided, vLLM will add it to the server using app.add_middleware(). 
" + ) + + parser = AsyncEngineArgs.add_cli_args(parser) + return parser.parse_args() + + +# Add prometheus asgi middleware to route /metrics requests +metrics_app = make_asgi_app() +app.mount("/metrics", metrics_app) + + +@app.exception_handler(RequestValidationError) +async def validation_exception_handler(_, exc): + err = openai_serving_chat.create_error_response(message=str(exc)) + return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) + + +@app.get("/health") +async def health() -> Response: + """Health check.""" + return Response(status_code=200) + + +@app.get("/v1/models") +async def show_available_models(): + models = await openai_serving_chat.show_available_models() + return JSONResponse(content=models.model_dump()) + + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest, + raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest, raw_request: Request): + generator = await openai_serving_completion.create_completion( + request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + if request.stream: + return StreamingResponse(content=generator, + media_type="text/event-stream") + else: + return JSONResponse(content=generator.model_dump()) + + +if __name__ == "__main__": + args = parse_args() + + app.add_middleware( + CORSMiddleware, + allow_origins=args.allowed_origins, + allow_credentials=args.allow_credentials, + allow_methods=args.allowed_methods, + allow_headers=args.allowed_headers, + ) + + if token := os.environ.get("VLLM_API_KEY") or args.api_key: + + @app.middleware("http") + async def authentication(request: Request, call_next): + if not request.url.path.startswith("/v1"): + return await call_next(request) + if request.headers.get("Authorization") != "Bearer " + token: + return JSONResponse(content={"error": "Unauthorized"}, + status_code=401) + return await call_next(request) + + for middleware in args.middleware: + module_path, object_name = middleware.rsplit(".", 1) + imported = getattr(importlib.import_module(module_path), object_name) + if inspect.isclass(imported): + app.add_middleware(imported) + elif inspect.iscoroutinefunction(imported): + app.middleware("http")(imported) + else: + raise ValueError( + f"Invalid middleware {middleware}. Must be a function or a class." 
+ ) + + logger.info(f"args: {args}") + + if args.served_model_name is not None: + served_model = args.served_model_name + else: + served_model = args.model + + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) + openai_serving_chat = OpenAIServingChat(engine, served_model, + args.response_role, + args.lora_modules, + args.chat_template) + openai_serving_completion = OpenAIServingCompletion( + engine, served_model, args.lora_modules) + + app.root_path = args.root_path + uvicorn.run(app, + host=args.host, + port=args.port, + log_level="info", + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py new file mode 100644 index 0000000..26499b8 --- /dev/null +++ b/vllm/entrypoints/openai/protocol.py @@ -0,0 +1,323 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from typing import Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, Field, model_validator + +from vllm.utils import random_uuid +from vllm.sampling_params import SamplingParams + +import torch + + +class ErrorResponse(BaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(BaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: str = False + + +class ModelCard(BaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(BaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(BaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ChatCompletionRequest(BaseModel): + model: str + messages: List[Dict[str, str]] + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + max_tokens: Optional[int] = None + seed: Optional[int] = None + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + logprobs: Optional[bool] = False + top_logprobs: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by vLLM + best_of: Optional[int] = None + top_k: Optional[int] = -1 + ignore_eos: Optional[bool] = False + use_beam_search: Optional[bool] = False + early_stopping: Optional[bool] = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + skip_special_tokens: Optional[bool] = True + spaces_between_special_tokens: Optional[bool] = True + add_generation_prompt: Optional[bool] = True + echo: Optional[bool] = False + repetition_penalty: Optional[float] = 1.0 + min_p: Optional[float] = 0.0 + 
include_stop_str_in_output: Optional[bool] = False + length_penalty: Optional[float] = 1.0 + guided_json: Optional[Union[str, dict, BaseModel]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + + def to_sampling_params(self) -> SamplingParams: + if self.logprobs and not self.top_logprobs: + raise ValueError("Top logprobs must be set when logprobs is.") + + logits_processors = None + if self.logit_bias: + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in self.logit_bias.items(): + # Clamp the bias between -100 and 100 per OpenAI API spec + bias = min(100, max(-100, bias)) + logits[int(token_id)] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + max_tokens=self.max_tokens, + logprobs=self.top_logprobs if self.logprobs else None, + prompt_logprobs=self.top_logprobs if self.echo else None, + best_of=self.best_of, + top_k=self.top_k, + ignore_eos=self.ignore_eos, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=self.spaces_between_special_tokens, + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + +class CompletionRequest(BaseModel): + model: str + # a string, array of strings, array of tokens, or array of token arrays + prompt: Union[List[int], List[List[int]], str, List[str]] + suffix: Optional[str] = None + max_tokens: Optional[int] = 16 + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + n: Optional[int] = 1 + stream: Optional[bool] = False + logprobs: Optional[int] = None + echo: Optional[bool] = False + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + seed: Optional[int] = None + presence_penalty: Optional[float] = 0.0 + frequency_penalty: Optional[float] = 0.0 + best_of: Optional[int] = None + logit_bias: Optional[Dict[str, float]] = None + user: Optional[str] = None + # Additional parameters supported by vLLM + top_k: Optional[int] = -1 + ignore_eos: Optional[bool] = False + use_beam_search: Optional[bool] = False + early_stopping: Optional[bool] = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + skip_special_tokens: Optional[bool] = True + spaces_between_special_tokens: Optional[bool] = True + repetition_penalty: Optional[float] = 1.0 + min_p: Optional[float] = 0.0 + include_stop_str_in_output: Optional[bool] = False + length_penalty: Optional[float] = 1.0 + guided_json: Optional[Union[str, dict, BaseModel]] = None + guided_regex: Optional[str] = None + guided_choice: Optional[List[str]] = None + + def 
to_sampling_params(self): + echo_without_generation = self.echo and self.max_tokens == 0 + + logits_processors = None + if self.logit_bias: + + def logit_bias_logits_processor( + token_ids: List[int], + logits: torch.Tensor) -> torch.Tensor: + for token_id, bias in self.logit_bias.items(): + # Clamp the bias between -100 and 100 per OpenAI API spec + bias = min(100, max(-100, bias)) + logits[int(token_id)] += bias + return logits + + logits_processors = [logit_bias_logits_processor] + + return SamplingParams( + n=self.n, + best_of=self.best_of, + presence_penalty=self.presence_penalty, + frequency_penalty=self.frequency_penalty, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=self.top_k, + min_p=self.min_p, + seed=self.seed, + stop=self.stop, + stop_token_ids=self.stop_token_ids, + ignore_eos=self.ignore_eos, + max_tokens=self.max_tokens if not echo_without_generation else 1, + logprobs=self.logprobs, + use_beam_search=self.use_beam_search, + early_stopping=self.early_stopping, + prompt_logprobs=self.logprobs if self.echo else None, + skip_special_tokens=self.skip_special_tokens, + spaces_between_special_tokens=(self.spaces_between_special_tokens), + include_stop_str_in_output=self.include_stop_str_in_output, + length_penalty=self.length_penalty, + logits_processors=logits_processors, + ) + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + +class LogProbs(BaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None + + +class CompletionResponseChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: UsageInfo + + +class CompletionResponseStreamChoice(BaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class CompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class ChatMessage(BaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(BaseModel): + index: int + message: ChatMessage + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + 
choices: List[ChatCompletionResponseChoice] + usage: UsageInfo + + +class DeltaMessage(BaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(BaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[LogProbs] = None + finish_reason: Optional[Literal["stop", "length"]] = None + + +class ChatCompletionStreamResponse(BaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py new file mode 100644 index 0000000..f4ad0aa --- /dev/null +++ b/vllm/entrypoints/openai/serving_chat.py @@ -0,0 +1,307 @@ +import time +import codecs +from fastapi import Request +from typing import AsyncGenerator, AsyncIterator, Optional, List, Union +from vllm.logger import init_logger +from vllm.utils import random_uuid +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + UsageInfo) +from vllm.outputs import RequestOutput +from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor + +logger = init_logger(__name__) + + +class OpenAIServingChat(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + response_role: str, + lora_modules: Optional[List[LoRA]] = None, + chat_template=None): + super().__init__(engine=engine, + served_model=served_model, + lora_modules=lora_modules) + self.response_role = response_role + self._load_chat_template(chat_template) + + async def create_chat_completion( + self, request: ChatCompletionRequest, raw_request: Request + ) -> Union[ErrorResponse, AsyncGenerator[str, None], + ChatCompletionResponse]: + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/chat/create + for the API specification. This API mimics the OpenAI ChatCompletion API. 
+ + NOTE: Currently we do not support the following feature: + - function_call (Users should implement this by themselves) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + try: + prompt = self.tokenizer.apply_chat_template( + conversation=request.messages, + tokenize=False, + add_generation_prompt=request.add_generation_prompt) + except Exception as e: + logger.error( + f"Error in applying chat template from request: {str(e)}") + return self.create_error_response(str(e)) + + request_id = f"cmpl-{random_uuid()}" + try: + token_ids = self._validate_prompt_and_tokenize(request, + prompt=prompt) + sampling_params = request.to_sampling_params() + lora_request = self._maybe_get_lora(request) + guided_decode_logits_processor = ( + await get_guided_decoding_logits_processor( + request, self.engine.get_tokenizer())) + if guided_decode_logits_processor: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append( + guided_decode_logits_processor) + except ValueError as e: + return self.create_error_response(str(e)) + + result_generator = self.engine.generate(prompt, sampling_params, + request_id, token_ids, + lora_request) + # Streaming response + if request.stream: + return self.chat_completion_stream_generator( + request, result_generator, request_id) + else: + return await self.chat_completion_full_generator( + request, raw_request, result_generator, request_id) + + def get_chat_request_role(self, request: ChatCompletionRequest) -> str: + if request.add_generation_prompt: + return self.response_role + else: + return request.messages[-1]["role"] + + async def chat_completion_stream_generator( + self, request: ChatCompletionRequest, + result_generator: AsyncIterator[RequestOutput], request_id: str + ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: + + model_name = request.model + created_time = int(time.monotonic()) + chunk_object_type = "chat.completion.chunk" + + # Send first response for each request.n (index) with the role + role = self.get_chat_request_role(request) + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(role=role), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse(id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response to echo the input portion of the last message + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + if last_msg_content: + for i in range(request.n): + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=last_msg_content), + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + logprobs=None, + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + + # Send response for each token for each request.n (index) + previous_texts = [""] * request.n + previous_num_tokens = [0] * request.n + finish_reason_sent = [False] * request.n + async for res in result_generator: + res: RequestOutput + for output in 
res.outputs: + i = output.index + + if finish_reason_sent[i]: + continue + + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + delta_text = output.text[len(previous_texts[i]):] + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + if output.finish_reason is None: + # Send token-by-token response for each request.n + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + else: + # Send the finish response for each request.n only once + prompt_tokens = len(res.prompt_token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=previous_num_tokens[i], + total_tokens=prompt_tokens + previous_num_tokens[i], + ) + choice_data = ChatCompletionResponseStreamChoice( + index=i, + delta=DeltaMessage(content=delta_text), + logprobs=logprobs, + finish_reason=output.finish_reason) + chunk = ChatCompletionStreamResponse( + id=request_id, + object=chunk_object_type, + created=created_time, + choices=[choice_data], + model=model_name) + if final_usage is not None: + chunk.usage = final_usage + data = chunk.model_dump_json(exclude_unset=True, + exclude_none=True) + yield f"data: {data}\n\n" + finish_reason_sent[i] = True + # Send the final done message after all response.n are finished + yield "data: [DONE]\n\n" + + async def chat_completion_full_generator( + self, request: ChatCompletionRequest, raw_request: Request, + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]: + + model_name = request.model + created_time = int(time.monotonic()) + final_res: RequestOutput = None + + async for res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(request_id) + return self.create_error_response("Client disconnected") + final_res = res + assert final_res is not None + + choices = [] + + role = self.get_chat_request_role(request) + for output in final_res.outputs: + token_ids = output.token_ids + top_logprobs = output.logprobs + + if request.logprobs: + logprobs = self._create_logprobs( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = ChatCompletionResponseChoice( + index=output.index, + message=ChatMessage(role=role, content=output.text), + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + if request.echo: + last_msg_content = "" + if request.messages and isinstance( + request.messages, list) and request.messages[-1].get( + "content") and request.messages[-1].get( + "role") == role: + last_msg_content = request.messages[-1]["content"] + + for choice in choices: + full_message = last_msg_content + choice.message.content + choice.message.content = full_message + + num_prompt_tokens = len(final_res.prompt_token_ids) + num_generated_tokens = sum( + len(output.token_ids) for output in final_res.outputs) + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + response = ChatCompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + return response + + def _load_chat_template(self, chat_template): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + self.tokenizer.chat_template = f.read() + except OSError: + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + self.tokenizer.chat_template = codecs.decode( + chat_template, "unicode_escape") + + logger.info( + f"Using supplied chat template:\n{self.tokenizer.chat_template}" + ) + elif self.tokenizer.chat_template is not None: + logger.info( + f"Using default chat template:\n{self.tokenizer.chat_template}" + ) + else: + logger.warning( + "No chat template provided. 
Chat API will not work.") diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py new file mode 100644 index 0000000..99a1019 --- /dev/null +++ b/vllm/entrypoints/openai/serving_completion.py @@ -0,0 +1,361 @@ +import asyncio +import time +from fastapi import Request +from typing import AsyncGenerator, AsyncIterator, Callable, List, Optional, Dict, Tuple +from vllm.logger import init_logger +from vllm.utils import random_uuid +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import ( + CompletionRequest, + CompletionResponse, + CompletionResponseChoice, + CompletionResponseStreamChoice, + CompletionStreamResponse, + LogProbs, + UsageInfo, +) +from vllm.outputs import RequestOutput +from vllm.entrypoints.openai.serving_engine import OpenAIServing, LoRA +from vllm.model_executor.guided_decoding import get_guided_decoding_logits_processor + +logger = init_logger(__name__) + +TypeTokenIDs = List[int] +TypeTopLogProbs = List[Optional[Dict[int, float]]] +TypeCreateLogProbsFn = Callable[ + [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], LogProbs] + + +async def completion_stream_generator( + request: CompletionRequest, + raw_request: Request, + on_abort, + result_generator: AsyncIterator[Tuple[int, RequestOutput]], + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, + num_prompts: int, +) -> AsyncGenerator[str, None]: + previous_texts = [""] * request.n * num_prompts + previous_num_tokens = [0] * request.n * num_prompts + has_echoed = [False] * request.n * num_prompts + + async for prompt_idx, res in result_generator: + + # Abort the request if the client disconnects. + if await raw_request.is_disconnected(): + await on_abort(f"{request_id}-{prompt_idx}") + raise StopAsyncIteration() + + for output in res.outputs: + i = output.index + prompt_idx * request.n + # TODO(simon): optimize the performance by avoiding full text O(n^2) sending. 
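A minimal illustrative sketch (editorial, not part of the commit) of the index flattening used just above: with num_prompts prompts and request.n completions per prompt, the per-choice bookkeeping lists are flat, and choice output.index of prompt prompt_idx lands at slot output.index + prompt_idx * n.

# Illustrative only: reproduces the flat indexing scheme used by
# completion_stream_generator above; n and num_prompts are made-up values.
n = 2            # hypothetical request.n
num_prompts = 3  # hypothetical number of prompts in the request
slots = [None] * (n * num_prompts)
for prompt_idx in range(num_prompts):
    for output_index in range(n):
        slots[output_index + prompt_idx * n] = (prompt_idx, output_index)
assert slots[5] == (2, 1)  # choice 1 of prompt 2 occupies the last slot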
+ + if request.echo and request.max_tokens == 0: + # only return the prompt + delta_text = res.prompt + delta_token_ids = res.prompt_token_ids + top_logprobs = res.prompt_logprobs + has_echoed[i] = True + elif request.echo and request.max_tokens > 0 and not has_echoed[i]: + # echo the prompt and first token + delta_text = res.prompt + output.text + delta_token_ids = res.prompt_token_ids + output.token_ids + top_logprobs = res.prompt_logprobs + (output.logprobs or []) + has_echoed[i] = True + else: + # return just the delta + delta_text = output.text[len(previous_texts[i]):] + delta_token_ids = output.token_ids[previous_num_tokens[i]:] + top_logprobs = output.logprobs[ + previous_num_tokens[i]:] if output.logprobs else None + + if request.logprobs is not None: + assert top_logprobs is not None, "top_logprobs must be provided when logprobs is requested" + logprobs = create_logprobs_fn( + token_ids=delta_token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + initial_text_offset=len(previous_texts[i]), + ) + else: + logprobs = None + + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + finish_reason = output.finish_reason + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text=delta_text, + logprobs=logprobs, + finish_reason=finish_reason, + ) + ]).model_dump_json() + yield f"data: {response_json}\n\n" + + if output.finish_reason is not None: # return final usage + logprobs = LogProbs() if request.logprobs is not None else None + prompt_tokens = len(res.prompt_token_ids) + completion_tokens = len(output.token_ids) + final_usage = UsageInfo( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + response_json = CompletionStreamResponse( + id=request_id, + created=created_time, + model=model_name, + choices=[ + CompletionResponseStreamChoice( + index=i, + text="", + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + ], + usage=final_usage, + ).model_dump_json() + yield f"data: {response_json}\n\n" + + yield "data: [DONE]\n\n" + + +def parse_prompt_format(prompt) -> Tuple[bool, list]: + # get the prompt, openai supports the following + # "a string, array of strings, array of tokens, or array of token arrays." 
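Illustrative examples (editorial, not from the commit) of the four prompt shapes that the comment above refers to and that the cases below dispatch on; the token ids are made up.

# Illustrative only: the four prompt shapes accepted by the completions
# endpoint, matching the cases handled below (token ids are invented).
prompt_string = "Hello, world"                    # case 1: a single string
prompt_strings = ["Hello, world", "Hola, mundo"]  # case 2: array of strings
prompt_tokens = [9906, 11, 1917]                  # case 3: array of tokens
prompt_token_arrays = [[9906, 11], [1917]]        # case 4: array of token arrays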
+ prompt_is_tokens = False + prompts = [prompt] # case 1: a string + if isinstance(prompt, list): + if len(prompt) == 0: + raise ValueError("please provide at least one prompt") + elif isinstance(prompt[0], str): + prompt_is_tokens = False + prompts = prompt # case 2: array of strings + elif isinstance(prompt[0], int): + prompt_is_tokens = True + prompts = [prompt] # case 3: array of tokens + elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int): + prompt_is_tokens = True + prompts = prompt # case 4: array of token arrays + else: + raise ValueError( + "prompt must be a string, array of strings, array of tokens, or array of token arrays" + ) + return prompt_is_tokens, prompts + + +def request_output_to_completion_response( + final_res_batch: List[RequestOutput], + request: CompletionRequest, + create_logprobs_fn: TypeCreateLogProbsFn, + request_id: str, + created_time: int, + model_name: str, +) -> CompletionResponse: + choices = [] + num_prompt_tokens = 0 + num_generated_tokens = 0 + for final_res in final_res_batch: + assert final_res is not None + prompt_token_ids = final_res.prompt_token_ids + prompt_logprobs = final_res.prompt_logprobs + prompt_text = final_res.prompt + + for output in final_res.outputs: + if request.echo and request.max_tokens == 0: + token_ids = prompt_token_ids + top_logprobs = prompt_logprobs + output_text = prompt_text + elif request.echo and request.max_tokens > 0: + token_ids = prompt_token_ids + output.token_ids + top_logprobs = prompt_logprobs + output.logprobs + output_text = prompt_text + output.text + else: + token_ids = output.token_ids + top_logprobs = output.logprobs + output_text = output.text + + if request.logprobs is not None: + logprobs = create_logprobs_fn( + token_ids=token_ids, + top_logprobs=top_logprobs, + num_output_top_logprobs=request.logprobs, + ) + else: + logprobs = None + + choice_data = CompletionResponseChoice( + index=len(choices), + text=output_text, + logprobs=logprobs, + finish_reason=output.finish_reason, + ) + choices.append(choice_data) + + num_prompt_tokens += len(prompt_token_ids) + num_generated_tokens += sum( + len(output.token_ids) for output in final_res.outputs) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + completion_tokens=num_generated_tokens, + total_tokens=num_prompt_tokens + num_generated_tokens, + ) + + return CompletionResponse( + id=request_id, + created=created_time, + model=model_name, + choices=choices, + usage=usage, + ) + + +def merge_async_iterators(*iterators): + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. 
+ """ + queue = asyncio.Queue() + + finished = [False] * len(iterators) + + async def producer(i, iterator): + async for item in iterator: + await queue.put((i, item)) + finished[i] = True + + _tasks = [ + asyncio.create_task(producer(i, iterator)) + for i, iterator in enumerate(iterators) + ] + + async def consumer(): + while not all(finished) or not queue.empty(): + item = await queue.get() + yield item + await asyncio.gather(*_tasks) + + return consumer() + + +class OpenAIServingCompletion(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + lora_modules: Optional[List[LoRA]] = None): + super().__init__(engine=engine, + served_model=served_model, + lora_modules=lora_modules) + + async def create_completion(self, request: CompletionRequest, + raw_request: Request): + """Completion API similar to OpenAI's API. + + See https://platform.openai.com/docs/api-reference/completions/create + for the API specification. This API mimics the OpenAI Completion API. + + NOTE: Currently we do not support the following feature: + - suffix (the language models we currently support do not support + suffix) + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + # Return error for unsupported features. + if request.suffix is not None: + return self.create_error_response( + "suffix is not currently supported") + + model_name = request.model + request_id = f"cmpl-{random_uuid()}" + created_time = int(time.monotonic()) + + # Schedule the request and get the result generator. + generators = [] + try: + sampling_params = request.to_sampling_params() + lora_request = self._maybe_get_lora(request) + guided_decode_logit_processor = ( + await get_guided_decoding_logits_processor( + request, self.engine.get_tokenizer())) + if guided_decode_logit_processor is not None: + if sampling_params.logits_processors is None: + sampling_params.logits_processors = [] + sampling_params.logits_processors.append( + guided_decode_logit_processor) + prompt_is_tokens, prompts = parse_prompt_format(request.prompt) + + for i, prompt in enumerate(prompts): + if prompt_is_tokens: + input_ids = self._validate_prompt_and_tokenize( + request, prompt_ids=prompt) + else: + input_ids = self._validate_prompt_and_tokenize( + request, prompt=prompt) + + generators.append( + self.engine.generate(prompt, + sampling_params, + f"{request_id}-{i}", + prompt_token_ids=input_ids, + lora_request=lora_request)) + except ValueError as e: + return self.create_error_response(str(e)) + + result_generator: AsyncIterator[Tuple[ + int, RequestOutput]] = merge_async_iterators(*generators) + + # Similar to the OpenAI API, when n != best_of, we do not stream the + # results. In addition, we do not stream the results when use beam search. + stream = (request.stream + and (request.best_of is None or request.n == request.best_of) + and not request.use_beam_search) + + # Streaming response + if stream: + return completion_stream_generator(request, + raw_request, + self.engine.abort, + result_generator, + self._create_logprobs, + request_id, + created_time, + model_name, + num_prompts=len(prompts)) + + # Non-streaming response + final_res_batch: RequestOutput = [None] * len(prompts) + async for i, res in result_generator: + if await raw_request.is_disconnected(): + # Abort the request if the client disconnects. 
+ await self.engine.abort(f"{request_id}-{i}") + return self.create_error_response("Client disconnected") + final_res_batch[i] = res + response = request_output_to_completion_response( + final_res_batch, request, self._create_logprobs, request_id, + created_time, model_name) + + # When user requests streaming but we don't stream, we still need to + # return a streaming response with a single event. + if request.stream: + response_json = response.model_dump_json() + + async def fake_stream_generator() -> AsyncGenerator[str, None]: + yield f"data: {response_json}\n\n" + yield "data: [DONE]\n\n" + + return fake_stream_generator() + + return response diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py new file mode 100644 index 0000000..0994547 --- /dev/null +++ b/vllm/entrypoints/openai/serving_engine.py @@ -0,0 +1,172 @@ +import asyncio +from dataclasses import dataclass +from http import HTTPStatus +from typing import Dict, List, Optional, Union +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (CompletionRequest, + ChatCompletionRequest, + ErrorResponse, LogProbs, + ModelCard, ModelList, + ModelPermission) +from vllm.lora.request import LoRARequest + +logger = init_logger(__name__) + + +@dataclass +class LoRA: + name: str + local_path: str + + +class OpenAIServing: + + def __init__(self, + engine: AsyncLLMEngine, + served_model: str, + lora_modules=Optional[List[LoRA]]): + self.engine = engine + self.served_model = served_model + if lora_modules is None: + self.lora_requests = [] + else: + self.lora_requests = [ + LoRARequest( + lora_name=lora.name, + lora_int_id=i, + lora_local_path=lora.local_path, + ) for i, lora in enumerate(lora_modules, start=1) + ] + + self.max_model_len = 0 + self.tokenizer = None + + try: + event_loop = asyncio.get_running_loop() + except RuntimeError: + event_loop = None + + if event_loop is not None and event_loop.is_running( + ): # If the current is instanced by Ray Serve, there is already a running event loop + event_loop.create_task(self._post_init()) + else: # When using single vLLM without engine_use_ray + asyncio.run(self._post_init()) + + async def _post_init(self): + engine_model_config = await self.engine.get_model_config() + self.max_model_len = engine_model_config.max_model_len + + # A separate tokenizer to map token IDs to strings. + self.tokenizer = get_tokenizer( + engine_model_config.tokenizer, + tokenizer_mode=engine_model_config.tokenizer_mode, + trust_remote_code=engine_model_config.trust_remote_code) + + async def show_available_models(self) -> ModelList: + """Show available models. 
Right now we only have one model.""" + model_cards = [ + ModelCard(id=self.served_model, + root=self.served_model, + permission=[ModelPermission()]) + ] + lora_cards = [ + ModelCard(id=lora.lora_name, + root=self.served_model, + permission=[ModelPermission()]) + for lora in self.lora_requests + ] + model_cards.extend(lora_cards) + return ModelList(data=model_cards) + + def _create_logprobs( + self, + token_ids: List[int], + top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None, + num_output_top_logprobs: Optional[int] = None, + initial_text_offset: int = 0, + ) -> LogProbs: + """Create OpenAI-style logprobs.""" + logprobs = LogProbs() + last_token_len = 0 + if num_output_top_logprobs: + logprobs.top_logprobs = [] + for i, token_id in enumerate(token_ids): + step_top_logprobs = top_logprobs[i] + if step_top_logprobs is not None: + token_logprob = step_top_logprobs[token_id] + else: + token_logprob = None + token = self.tokenizer.convert_ids_to_tokens(token_id) + logprobs.tokens.append(token) + logprobs.token_logprobs.append(token_logprob) + if len(logprobs.text_offset) == 0: + logprobs.text_offset.append(initial_text_offset) + else: + logprobs.text_offset.append(logprobs.text_offset[-1] + + last_token_len) + last_token_len = len(token) + + if num_output_top_logprobs: + logprobs.top_logprobs.append({ + self.tokenizer.convert_ids_to_tokens(i): p + for i, p in step_top_logprobs.items() + } if step_top_logprobs else None) + return logprobs + + def create_error_response( + self, + message: str, + err_type: str = "BadRequestError", + status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse: + return ErrorResponse(message=message, + type=err_type, + code=status_code.value) + + async def _check_model(self, request) -> Optional[ErrorResponse]: + if request.model == self.served_model: + return + if request.model in [lora.lora_name for lora in self.lora_requests]: + return + return self.create_error_response( + message=f"The model `{request.model}` does not exist.", + err_type="NotFoundError", + status_code=HTTPStatus.NOT_FOUND) + + def _maybe_get_lora(self, request) -> Optional[LoRARequest]: + if request.model == self.served_model: + return + for lora in self.lora_requests: + if request.model == lora.lora_name: + return lora + # if _check_model has been called earlier, this will be unreachable + raise ValueError("The model `{request.model}` does not exist.") + + def _validate_prompt_and_tokenize( + self, + request: Union[ChatCompletionRequest, CompletionRequest], + prompt: Optional[str] = None, + prompt_ids: Optional[List[int]] = None) -> List[int]: + if not (prompt or prompt_ids): + raise ValueError("Either prompt or prompt_ids should be provided.") + if (prompt and prompt_ids): + raise ValueError( + "Only one of prompt or prompt_ids should be provided.") + + input_ids = prompt_ids if prompt_ids is not None else self.tokenizer( + prompt).input_ids + token_num = len(input_ids) + + if request.max_tokens is None: + request.max_tokens = self.max_model_len - token_num + + if token_num + request.max_tokens > self.max_model_len: + raise ValueError( + f"This model's maximum context length is {self.max_model_len} tokens. " + f"However, you requested {request.max_tokens + token_num} tokens " + f"({token_num} in the messages, " + f"{request.max_tokens} in the completion). 
" + f"Please reduce the length of the messages or completion.", ) + else: + return input_ids diff --git a/vllm/logger.py b/vllm/logger.py new file mode 100644 index 0000000..d25fcef --- /dev/null +++ b/vllm/logger.py @@ -0,0 +1,61 @@ +# Adapted from +# https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py +"""Logging configuration for vLLM.""" +import logging +import sys +import os + +VLLM_CONFIGURE_LOGGING = int(os.getenv("VLLM_CONFIGURE_LOGGING", "1")) + +_FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" +_DATE_FORMAT = "%m-%d %H:%M:%S" + + +class NewLineFormatter(logging.Formatter): + """Adds logging prefix to newlines to align multi-line messages.""" + + def __init__(self, fmt, datefmt=None): + logging.Formatter.__init__(self, fmt, datefmt) + + def format(self, record): + msg = logging.Formatter.format(self, record) + if record.message != "": + parts = msg.split(record.message) + msg = msg.replace("\n", "\r\n" + parts[0]) + return msg + + +_root_logger = logging.getLogger("vllm") +_default_handler = None + + +def _setup_logger(): + _root_logger.setLevel(logging.DEBUG) + global _default_handler + if _default_handler is None: + _default_handler = logging.StreamHandler(sys.stdout) + _default_handler.flush = sys.stdout.flush # type: ignore + _default_handler.setLevel(logging.INFO) + _root_logger.addHandler(_default_handler) + fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) + _default_handler.setFormatter(fmt) + # Setting this will avoid the message + # being propagated to the parent logger. + _root_logger.propagate = False + + +# The logger is initialized when the module is imported. +# This is thread-safe as the module is only imported once, +# guaranteed by the Python GIL. 
+if VLLM_CONFIGURE_LOGGING: + _setup_logger() + + +def init_logger(name: str): + # Use the same settings as above for root logger + logger = logging.getLogger(name) + logger.setLevel(os.getenv("LOG_LEVEL", "DEBUG")) + if VLLM_CONFIGURE_LOGGING: + logger.addHandler(_default_handler) + logger.propagate = False + return logger diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000..e667d70 --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,979 @@ +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_nslice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of multiple sublayers + (slices) packed together. 
+ + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), where n is number of slices + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + offset_left = 0 + for slice_idx in range(len(output_slices)): + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class BaseLayerWithLoRA(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... 
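A small sketch (editorial, not part of the commit) of how the LoRAMapping dataclass defined above is meant to be read, based on its own comments: index_mapping carries one LoRA slot id per input token, prompt_mapping one per sampled token, and -1 means no LoRA. The class is re-declared here as a stand-in so the sketch runs without vLLM installed.

# Illustrative only: structurally identical stand-in for the LoRAMapping above.
from dataclasses import dataclass
from typing import Tuple


@dataclass
class LoRAMappingSketch:
    index_mapping: Tuple[int, ...]   # one LoRA slot id per input token
    prompt_mapping: Tuple[int, ...]  # one LoRA slot id per sampled token


# Batch with two sequences: seq 0 (3 tokens) served by LoRA slot 1,
# seq 1 (2 tokens) using no LoRA (-1).
mapping = LoRAMappingSketch(index_mapping=(1, 1, 1, -1, -1),
                            prompt_mapping=(1, -1))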
+ + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2] + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + 
self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + 
self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + (self.output_dim, self.output_dim), + ) + return output + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = ( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) + self.lora_b_stacked = ( + torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, + self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + 
if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + 
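+        # _apply_lora adds the low-rank update in place: for each token i,
+        # output[i] += x[i] @ lora_a_stacked[idx, 0].T @ lora_b_stacked[idx, 0].T,
+        # where idx is the LoRA slot selected by self.indices for that token
+        # (scaling is typically already folded into lora_b by
+        # LoRALayerWeights.optimize()).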
_apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class SamplerWithLoRA(BaseLayerWithLoRA): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def logits_as_hidden_states(self): + return self.base_layer.logits_as_hidden_states + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + if 32000 < self.base_layer.vocab_size > 33024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, 
non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + if logits is None: + return None + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + + # Remove paddings in vocab (if any). 
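+        # (logits may be wider than vocab_size here because the lm_head
+        # weight can be padded; the extra padding columns are dropped below.)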
+ logits = logits[:, :self.base_layer.vocab_size] + + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: + supported_layer_types = { + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> SamplerWithLoRA: + ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000..fbb228c --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,160 @@ +from typing import List, Optional + +import torch +from vllm.utils import in_wsl + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. 
qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000..7386d21 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,620 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache, in_wsl + +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule + +logger = logging.getLogger(__name__) + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. 
+ embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + indices_len: List of lengths of the above tensors. + """ + indices = list(mapping.index_mapping).copy() + embedding_indices = indices.copy() + lora_indices = indices.copy() + prompt_mapping = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(indices[i]) + if indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if indices[i] > 0 else 0 + indices[i] = i + lora_indices[i] = lora_idx + + indices = torch.tensor([indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size) + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1]) + + return (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, indices_len) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + ) -> None: + self.id = lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and not in_wsl() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in embedding_modules if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + embedding_modules[embeddings_module]].to( + device=device, dtype=dtype) 
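+                        # Pinning the CPU tensor allows the later
+                        # non_blocking copies to the GPU to run
+                        # asynchronously.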
+ if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in embedding_padding_modules + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None, + embedding_modules: Optional[Dict[str, str]] = None, + embedding_padding_modules: Optional[List[str]] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + embedding_modules=embedding_modules, + embedding_padding_modules=embedding_padding_modules, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. 
+ """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + # 4 is the number of indicies tensors defined above + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices + self.indices_len = [None] * 4 + + self.model: nn.Module = model + if hasattr(self.model, "supported_lora_modules"): + self.supported_lora_modules = copy.deepcopy( + self.model.supported_lora_modules) + self.packed_modules_mapping = copy.deepcopy( + self.model.packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. 
int id: {lora_model.id}, slot index: {index}") + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + """Remove a LoRA from a GPU buffer.""" + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def _add_lora(self, lora: LoRAModel) -> bool: + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager CPU cache.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._add_lora(lora) + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager CPU cache.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_index_to_id, + self.lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self._set_lora_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.lora_slots, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora( + self, + lora_id: int, + rank: int, + embedding_modules: Optional[Dict[str, str]] = None) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, BaseLayerWithLoRA): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in embedding_modules: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], + rank, + module.lora_a_stacked.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." 
+ r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.supported_lora_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + self._add_lora(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + was_added = False + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.lora_slots: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA 
adapter for a given model.""" + if not hasattr(model, "supported_lora_modules"): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000..fc74269 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,170 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + + +def _raise_import_error(e): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "punica LoRA kernels require compute capability >= 8.0") from e + else: + raise ImportError( + "punica LoRA kernels could not be imported. If you built vLLM " + "from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var " + "was set.") from e + + +def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + +def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + +def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. 
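+    This is what packed layers (e.g. qkv_proj) rely on, so that each
+    sub-LoRA writes into its own column slice of the shared output tensor.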
+ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000..bbbf488 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + """ + Request for a LoRA adapter. + + Note that this class should be be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + + lora_name: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, LoRARequest) and self.lora_int_id == value.lora_int_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000..f67a381 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. 
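+
+        For example, "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight"
+        is parsed into ("model.layers.0.self_attn.q_proj", True).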
+ """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000..7e92bc9 --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,238 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, Dict, List, Optional, Set, Type + +import torch + +from vllm.lora.models import (LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class AbstractWorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + ... + + @abstractmethod + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... + + +class WorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. 
+ + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + embedding_modules: Dict[str, str], + embedding_padding_modules: List[str], + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + self.embedding_modules = embedding_modules + self.embedding_padding_modules = embedding_padding_modules + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager: LoRAModelManager = lora_manager + return lora_manager.model + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + embedding_modules=self.embedding_modules, + embedding_padding_modules=self.embedding_padding_modules, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than " + f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." 
+ ) + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank, self.embedding_modules)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + ) -> Any: + lora_manager = create_lora_manager( + model, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager: LRUCacheLoRAModelManager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/__init__.py b/vllm/model_executor/__init__.py new file mode 100644 index 0000000..17a33ee --- /dev/null +++ b/vllm/model_executor/__init__.py @@ -0,0 +1,10 @@ +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_random_seed, get_model + +__all__ = [ + "InputMetadata", + "get_model", + "SamplingMetadata", + "set_random_seed", +] \ No newline at end of file diff --git a/vllm/model_executor/guided_decoding.py b/vllm/model_executor/guided_decoding.py new file mode 100644 index 0000000..a8573f8 --- /dev/null +++ b/vllm/model_executor/guided_decoding.py @@ -0,0 +1,99 @@ +import asyncio +import concurrent.futures +from copy import copy +from enum import Enum +from functools import lru_cache +from json 
import dumps as json_dumps +from re import escape as regex_escape +from typing import Union, Tuple +from pydantic import BaseModel + +from vllm.entrypoints.openai.protocol import CompletionRequest, ChatCompletionRequest +from vllm.model_executor.guided_logits_processors import JSONLogitsProcessor, RegexLogitsProcessor + + +class GuidedDecodingMode(Enum): + JSON = "json" + REGEX = "regex" + CHOICE = "choice" + + +global_thread_pool = None # used for generating logits processor fsm + + +async def get_guided_decoding_logits_processor( + request: Union[CompletionRequest, ChatCompletionRequest], + tokenizer) -> Union[JSONLogitsProcessor, RegexLogitsProcessor]: + """ + Given an OpenAI-compatible request, check for guided decoding parameters + and get the necessary logits processor for the given guide. + We cache logit processors by (guide, tokenizer), and on cache hit + we make a shallow copy to reuse the same underlying FSM. + """ + global global_thread_pool + guide, mode = _get_guide_and_mode(request) + if not guide: + return None + + if global_thread_pool is None: + global_thread_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=2) + loop = asyncio.get_running_loop() + + result = await loop.run_in_executor(global_thread_pool, + _get_cached_logits_processor, guide, + tokenizer, mode) + + logits_processor = copy(result) + # reset logits processor's internal state + logits_processor.init_state() + return logits_processor + + +def _get_guide_and_mode( + request: Union[CompletionRequest, ChatCompletionRequest] +) -> Tuple[str, GuidedDecodingMode]: + + if request.guided_json: + if not isinstance(request.guided_json, (str, dict, BaseModel)): + raise TypeError("JSON schema must be str, dict, or BaseModel") + + json = request.guided_json + if isinstance(json, dict): + # turn dict into hashable string + json = json_dumps(json, sort_keys=True) + elif isinstance(json, BaseModel): + # use pydantic signature so that different model classes + # with the same fields will get hashed the same + json = str(json.__signature__) + return json, GuidedDecodingMode.JSON + + elif request.guided_regex: + if not isinstance(request.guided_regex, str): + raise TypeError("Regex must be string") + return request.guided_regex, GuidedDecodingMode.REGEX + + elif request.guided_choice: + if not isinstance(request.guided_choice, list): + raise TypeError("Choices must be a list") + + # choice just uses regex + choices = [ + regex_escape(str(choice)) for choice in request.guided_choice + ] + choices_regex = "(" + "|".join(choices) + ")" + return choices_regex, GuidedDecodingMode.CHOICE + + else: + return None, None + + +@lru_cache(maxsize=32) +def _get_cached_logits_processor(guide: str, tokenizer, + mode: GuidedDecodingMode): + if mode == GuidedDecodingMode.JSON: + return JSONLogitsProcessor(guide, tokenizer) + elif mode == GuidedDecodingMode.REGEX or mode == GuidedDecodingMode.CHOICE: + return RegexLogitsProcessor(guide, tokenizer) + else: + raise ValueError(f"Unknown guided decoding mode {mode}") diff --git a/vllm/model_executor/guided_logits_processors.py b/vllm/model_executor/guided_logits_processors.py new file mode 100644 index 0000000..1b3e5e7 --- /dev/null +++ b/vllm/model_executor/guided_logits_processors.py @@ -0,0 +1,129 @@ +# Copyright 2024- the Outlines developers +# This file is adapted from +# https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +import math +from collections import defaultdict +from typing import Union, DefaultDict, Dict, List, Optional + +import torch +from pydantic import BaseModel +from outlines.fsm.fsm import RegexFSM +from outlines.fsm.json_schema import build_regex_from_schema + + +class RegexLogitsProcessor: + + def __init__(self, regex_string: str, tokenizer): + """Compile the FSM that drives the regex-structured generation. + + Parameters + ---------- + regex_string + A string that represents a regular expression + tokenizer + The model's tokenizer + + """ + tokenizer = self.adapt_tokenizer(tokenizer) + fsm = RegexFSM(regex_string, tokenizer) + self.fsm = fsm + + def init_state(self): + """Initialize the FSM states.""" + self.fsm_state: DefaultDict[int, int] = defaultdict(int) + + def __call__(self, input_ids: List[int], + scores: torch.Tensor) -> torch.Tensor: + """Use the FSM to bias the logits before sampling the next token.""" + + seq_id = hash(tuple(input_ids)) + + if len(input_ids) == 0: + self.init_state() + else: + last_token = input_ids[-1] + last_seq_id = hash(tuple(input_ids[:-1])) + self.fsm_state[seq_id] = self.fsm.next_state( + self.fsm_state[last_seq_id], last_token) + + allowed_tokens = self.fsm.allowed_token_ids(self.fsm_state[seq_id]) + + mask = torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) + mask[allowed_tokens] = 0 + scores.add_(mask) + + return scores + + def adapt_tokenizer(self, tokenizer): + """Adapt vLLM's tokenizer to use to compile the FSM. + + The API of Outlines tokenizers is slightly different to that of + `transformers`. In addition we need to handle the missing spaces to + Llama's tokenizer to be able to compile FSMs for this model. + + """ + tokenizer.vocabulary = tokenizer.get_vocab() + tokenizer.special_tokens = set(tokenizer.all_special_tokens) + + def convert_token_to_string(token: str) -> str: + from transformers.file_utils import SPIECE_UNDERLINE + + string = tokenizer.convert_tokens_to_string([token]) + + # A hack to handle missing spaces to HF's Llama tokenizers + if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>": + return " " + string + + return string + + tokenizer.convert_token_to_string = convert_token_to_string + + return tokenizer + + +class JSONLogitsProcessor(RegexLogitsProcessor): + + def __init__(self, + schema: Union[str, Dict, BaseModel], + tokenizer, + whitespace_pattern: Optional[str] = None): + """Compile the FSM that drives the JSON-guided generation. + + Parameters + ---------- + schema + A JSON schema that encodes the structure we want the model to generate + tokenizer + The model's tokenizer + whitespace_pattern + Pattern to use for JSON syntactic whitespace (doesn't impact string literals) + Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` + """ + if isinstance(schema, type(BaseModel)): + schema_str = json.dumps(schema.model_json_schema()) + elif isinstance(schema, Dict): + schema_str = json.dumps(schema) + elif isinstance(schema, str): + schema_str = schema + else: + raise ValueError( + f"Cannot parse schema {schema}. 
The schema must be either " + + "a Pydantic object, a dictionary or a string that contains the JSON " + + "Schema specification") + regex_string = build_regex_from_schema(schema_str, whitespace_pattern) + super().__init__(regex_string, tokenizer) diff --git a/vllm/model_executor/input_metadata.py b/vllm/model_executor/input_metadata.py new file mode 100644 index 0000000..f0a88ac --- /dev/null +++ b/vllm/model_executor/input_metadata.py @@ -0,0 +1,54 @@ +from typing import Optional + +import torch + + +class InputMetadata: + """Metadata for input sequences. Used in PagedAttention. + + Args: + prompt_lens: Lengths of prompts. + slot_mapping: The address to write the new KV to of each token. + max_context_len: The maximum context length. + context_lens: the length of attention context for each sequence. + block_tables: The block tables. (Seq id -> list of physical block) + kv_cache_dtype: Data type to store kv cache. + """ + + def __init__( + self, + is_prompt: bool, + slot_mapping: torch.Tensor, + prompt_lens: Optional[torch.Tensor], + max_seq_len: Optional[int], + start_loc: Optional[torch.Tensor], + max_context_len: Optional[int], + context_lens: Optional[torch.Tensor], + block_tables: Optional[torch.Tensor], + use_cuda_graph: bool, + kv_cache_dtype: str, + ) -> None: + self.is_prompt = is_prompt + self.prompt_lens = prompt_lens + self.max_seq_len = max_seq_len + self.start_loc = start_loc + self.max_context_len = max_context_len + self.slot_mapping = slot_mapping + self.context_lens = context_lens + self.block_tables = block_tables + self.use_cuda_graph = use_cuda_graph + self.kv_cache_dtype = kv_cache_dtype + + # Set during the execution of the first attention op. + # FIXME(woosuk): This is a hack. + self.attn_bias = None + + def __repr__(self) -> str: + return ("InputMetadata(" + f"is_prompt={self.is_prompt}, " + f"max_context_len={self.max_context_len}, " + f"slot_mapping={self.slot_mapping}, " + f"context_lens={self.context_lens}, " + f"block_tables={self.block_tables}, " + f"use_cuda_graph={self.use_cuda_graph}, " + f"kv_cache_dtype={self.kv_cache_dtype})") diff --git a/vllm/model_executor/layers/__init__.py b/vllm/model_executor/layers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py new file mode 100644 index 0000000..7e41f8f --- /dev/null +++ b/vllm/model_executor/layers/activation.py @@ -0,0 +1,237 @@ +"""Custom activation functions.""" +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from vllm._C import ops +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import divide +from vllm.model_executor.utils import set_weight_attrs + + +class SiluAndMul(nn.Module): + """An activation function for SwiGLU. + + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[-1] // 2. 
+ + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + ops.silu_and_mul(out, x) + return out + + +class GeluAndMul(nn.Module): + """An activation function for GeGLU. + + The function computes x -> GELU(x[:d]) * x[d:] where d = x.shape[-1] // 2. + + Shapes: + x: (batch_size, seq_len, 2 * d) or (num_tokens, 2 * d) + return: (batch_size, seq_len, d) or (num_tokens, d) + """ + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + d = x.shape[-1] // 2 + return F.gelu(x[..., :d]) * x[..., d:] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + output_shape = (x.shape[:-1] + (d, )) + out = torch.empty(output_shape, dtype=x.dtype, device=x.device) + ops.gelu_and_mul(out, x) + return out + + +class NewGELU(nn.Module): + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + c = math.sqrt(2.0 / math.pi) + return 0.5 * x * (1.0 + torch.tanh(c * + (x + 0.044715 * torch.pow(x, 3.0)))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + ops.gelu_new(out, x) + return out + + +class FastGELU(nn.Module): + + def _forward(self, x: torch.Tensor) -> torch.Tensor: + """PyTorch-native implementation equivalent to forward().""" + return 0.5 * x * (1.0 + torch.tanh(x * 0.7978845608 * + (1.0 + 0.044715 * x * x))) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + ops.gelu_fast(out, x) + return out + + +class ScaledActivation(nn.Module): + """An activation function with post-scale parameters. + + This is used for some quantization methods like AWQ. 
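    An illustrative usage sketch (example sizes; with input_is_parallel=False no
    tensor-parallel state is needed, and `scales` would normally be loaded from a
    quantized checkpoint rather than filled by hand):

        act = ScaledActivation(nn.GELU(), intermediate_size=1024,
                               input_is_parallel=False)
        act.scales.data.fill_(1.0)
        y = act(torch.randn(2, 1024))    # element-wise act(x) / scales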
+ """ + + def __init__( + self, + act_module: nn.Module, + intermediate_size: int, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.act = act_module + self.input_is_parallel = input_is_parallel + if input_is_parallel: + tp_size = get_tensor_model_parallel_world_size() + intermediate_size_per_partition = divide(intermediate_size, + tp_size) + else: + intermediate_size_per_partition = intermediate_size + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.scales = nn.Parameter( + torch.empty(intermediate_size_per_partition, dtype=params_dtype)) + set_weight_attrs(self.scales, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.act(x) / self.scales + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor): + param_data = param.data + if self.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + shard_size = param_data.shape[0] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(0, start_idx, shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +_ACTIVATION_REGISTRY = { + "gelu": nn.GELU(), + "gelu_fast": FastGELU(), + "gelu_new": NewGELU(), + "gelu_pytorch_tanh": nn.GELU(approximate="tanh"), + "relu": nn.ReLU(), +} + + +def get_act_fn( + act_fn_name: str, + quant_config: Optional[QuantizationConfig] = None, + intermediate_size: Optional[int] = None, + input_is_parallel: bool = True, + params_dtype: Optional[torch.dtype] = None, +) -> nn.Module: + """Get an activation function by name.""" + act_fn_name = act_fn_name.lower() + if act_fn_name not in _ACTIVATION_REGISTRY: + raise ValueError( + f"Activation function {act_fn_name!r} is not supported.") + + act_fn = _ACTIVATION_REGISTRY[act_fn_name] + if (quant_config is not None + and act_fn_name in quant_config.get_scaled_act_names()): + if intermediate_size is None: + raise ValueError("intermediate_size must be specified for scaled " + "activation functions.") + return ScaledActivation(act_fn, intermediate_size, input_is_parallel, + params_dtype) + return act_fn + + +# ↓ add for smoothquant +class DequantSiluAndMulQuant(nn.Module): + """An activation function for SwiGLU. + The function computes x -> silu(x[:d]) * x[d:] where d = x.shape[1] // 2. 
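    A pure-PyTorch sketch of the per-token path (illustrative only; the real op is
    a single fused kernel and its exact scale handling may differ):

        gate, up = x.float().chunk(2, dim=-1)
        y = F.silu(gate * gate_dequant_scale) * (up * up_dequant_scale)
        scale = y.abs().amax(dim=-1, keepdim=True) / 127.0    # per-token scale
        out = torch.clamp(torch.round(y / scale), -128, 127).to(torch.int8)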
+ Shapes: + x: (num_tokens, 2 * d) + return: (num_tokens, d) + """ + + # TODO(Zhang Ying): use_per_token_quant + def __init__(self, + gate_dequant_scale: float = 1.0, + up_dequant_scale: float = 1.0, + quant_scale: float = 1.0, + use_per_token_quant: bool = True) -> None: + super().__init__() + self.register_parameter( + "gate_dequant_scale", + torch.nn.Parameter( + torch.tensor(gate_dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.register_parameter( + "up_dequant_scale", + torch.nn.Parameter( + torch.tensor(up_dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.register_parameter( + "quant_scale", + torch.nn.Parameter( + torch.tensor(quant_scale, dtype=torch.float32,requires_grad=False)) + ) + self.use_per_token_quant = use_per_token_quant + + def _apply(self, fn): + super()._apply(fn) + self.gate_dequant_scale.data = self.gate_dequant_scale.cpu() + self.up_dequant_scale.data = self.up_dequant_scale.cpu() + self.quant_scale.data = self.quant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.gate_dequant_scale.data = self.gate_dequant_scale.to(*args, **kwargs) + self.gate_dequant_scale.data = self.gate_dequant_scale.to(torch.float32) + self.up_dequant_scale.data = self.up_dequant_scale.to(*args, **kwargs) + self.up_dequant_scale.data = self.up_dequant_scale.to(torch.float32) + self.quant_scale.data = self.quant_scale.to(*args, **kwargs) + self.quant_scale.data = self.quant_scale.to(torch.float32) + return self + + def forward(self, x: torch.Tensor) -> torch.Tensor: + num_tokens = x.numel() // x.shape[-1] + d = x.shape[-1] // 2 + out = torch.empty(*x.shape[:-1], d, dtype=torch.int8, device=x.device) + if self.use_per_token_quant: + scale = torch.empty(num_tokens, + dtype=torch.float32, + device=x.device) + # tmp is used in kernel func + tmp = torch.empty(num_tokens, + d, + dtype=torch.float32, + device=x.device) + ops.dequant_silu_and_mul_quant( + out, x, self.gate_dequant_scale.item(), self.up_dequant_scale.item(), + scale, tmp) + return out, scale + else: + ops.dequant_silu_and_mul_quant( + out, x, self.gate_dequant_scale.item(), self.up_dequant_scale.item(), + self.quant_scale.item()) + return out + diff --git a/vllm/model_executor/layers/attention.py b/vllm/model_executor/layers/attention.py new file mode 100644 index 0000000..5c6231b --- /dev/null +++ b/vllm/model_executor/layers/attention.py @@ -0,0 +1,542 @@ +"""Multi-head attention.""" +import os +enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN",None) +from typing import List, Optional + +import importlib +import torch +import torch.nn as nn +from ixformer.contrib.xformers import ops as xops +from ixformer.contrib.xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + LowerTriangularMaskWithTensorBias) + +from vllm._C import ops +from vllm._C import cache_ops +from vllm.model_executor.input_metadata import InputMetadata +## from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( +## context_attention_fwd) +from vllm.utils import is_hip + +# _SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256] +# # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +# _PARTITION_SIZE = 512 +_SUPPORTED_HEAD_SIZES = [64, 128, 256] +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 256 + + +class PagedAttention(nn.Module): + """MHA/MQA/GQA layer with PagedAttention. + + This class takes query, key, and value tensors as input. 
The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Reshape and store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention using either + xformers or the PagedAttention custom op. + 3. Return the output tensor. + """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads + self.sliding_window = sliding_window + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.register_buffer("alibi_slopes", alibi_slopes, persistent=False) + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if self.head_size not in _SUPPORTED_HEAD_SIZES: + raise ValueError(f"head_size ({self.head_size}) is not supported. " + f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.") + + self.use_ref_attention = self.check_use_ref_attention() + + # TODO align vllm do not need those + self.attn_op = xops.fmha.flash.FwOp() + head_mapping = torch.repeat_interleave( + torch.arange(self.num_kv_heads, dtype=torch.int32), + self.num_queries_per_kv) + self.register_buffer("head_mapping", head_mapping, persistent=False) + + def check_use_ref_attention(self) -> bool: + if not is_hip(): + return False + # For ROCm, check whether flash attention is installed or not. + # if not, use_ref_attention needs to be True + return importlib.util.find_spec("flash_attn") is None + + def ref_masked_attention( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> torch.Tensor: + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + seq_len, _, _ = query.shape + attn_mask = torch.triu(torch.ones(seq_len, + seq_len, + dtype=query.dtype, + device=query.device), + diagonal=1) + attn_mask = attn_mask * torch.finfo(query.dtype).min + + attn_weights = self.scale * torch.einsum("qhd,khd->hqk", query, + key).float() + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + """PagedAttention forward pass. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for the inputs. + cache_event: event to wait for the cache operations to finish. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + """ + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
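        # Illustrative shapes (example values, not from the source): with
        # num_heads = 32, num_kv_heads = 8 (GQA) and head_size = 128,
        #   query:       [num_tokens, 4096] -> [num_tokens, 32, 128]
        #   key / value: [num_tokens, 1024] -> [num_tokens, 8, 128]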
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + slot_mapping = input_metadata.slot_mapping + + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. + if key_cache is not None and value_cache is not None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping, + ) + + if input_metadata.is_prompt: + # normal attention + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(input_metadata.prompt_lens) + input_metadata.attn_bias = attn_bias + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(num_tokens, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=self.attn_op, + alibi_slopes=self.alibi_slopes + ) + output = out.view_as(query) + else: + # prefix-enabled attention + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + else: + # Decoding run. + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.head_mapping, # self.num_kv_heads + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. + return output.view(num_tokens, hidden_size) + # TODO align + """ + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + PagedAttention forward pass. + + Args: + query: shape = [batch_size, seq_len, num_heads * head_size] + key: shape = [batch_size, seq_len, num_kv_heads * head_size] + value: shape = [batch_size, seq_len, num_kv_heads * head_size] + key_cache: shape = [num_blocks, num_kv_heads, head_size/x, + block_size, x] + value_cache: shape = [num_blocks, num_kv_heads, head_size, + block_size] + input_metadata: metadata for the inputs. + Returns: + shape = [batch_size, seq_len, num_heads * head_size] + + batch_size, seq_len, hidden_size = query.shape + # Reshape the query, key, and value tensors. 
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Reshape the keys and values and store them in the cache. + # If key_cache and value_cache are not provided, the new key and value + # vectors will not be cached. This happens during the initial memory + # profiling run. + if key_cache is not None and value_cache is not None: + cache_ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + input_metadata.slot_mapping.flatten(), + input_metadata.kv_cache_dtype, + ) + + if input_metadata.is_prompt: + # normal attention + if (key_cache is None or value_cache is None + or input_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + # As of Nov 2023, xformers only supports MHA. For MQA/GQA, + # project the key and value tensors to the desired number of + # heads. + # TODO(woosuk): Use MQA/GQA kernels for higher performance. + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], + self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + # Set attention bias if not provided. This typically happens at + # the very attention layer of every iteration. + # FIXME(woosuk): This is a hack. + if input_metadata.attn_bias is None: + if self.alibi_slopes is None: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + [seq_len] * batch_size) + if self.sliding_window is not None: + attn_bias = attn_bias.make_local_attention( + self.sliding_window) + input_metadata.attn_bias = attn_bias + else: + input_metadata.attn_bias = _make_alibi_bias( + self.alibi_slopes, self.num_kv_heads, batch_size, + seq_len, query.dtype) + + if self.use_ref_attention: + output = self.ref_masked_attention( + query, + key, + value, + ) + # Using view got RuntimeError: view size is not compatible with input tensor's size and stride + # (at least one dimension spans across two contiguous subspaces). Use reshape instead + return output.reshape(batch_size, seq_len, hidden_size) + + # TODO(woosuk): Too many view operations. Let's try to reduce + # them in the future for code readability. + if self.alibi_slopes is None: + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + query = query.unflatten(0, (batch_size, seq_len)) + key = key.unflatten(0, (batch_size, seq_len)) + value = value.unflatten(0, (batch_size, seq_len)) + + out = xops.memory_efficient_attention_forward( + query, + key, + value, + attn_bias=input_metadata.attn_bias, + p=0.0, + scale=self.scale, + op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if + (is_hip()) else None, + ) + output = out.view_as(query) + else: + # prefix-enabled attention + output = torch.empty_like(query) + context_attention_fwd( + query, + key, + value, + output, + key_cache, + value_cache, + input_metadata.block_tables, # [BS, max_block_per_request] + input_metadata.start_loc, + input_metadata.prompt_lens, + input_metadata.context_lens, + input_metadata.max_seq_len, + getattr(self, "alibi_slopes", None), + ) + + else: + # Decoding run. + output = _paged_attention( + query, + key_cache, + value_cache, + input_metadata, + self.num_kv_heads, + self.scale, + self.alibi_slopes, + ) + + # Reshape the output tensor. 
+ return output.view(batch_size, seq_len, hidden_size) + """ + + +def _make_alibi_bias( + alibi_slopes: torch.Tensor, + num_kv_heads: int, + batch_size: int, + seq_len: int, + dtype: torch.dtype, +) -> LowerTriangularMaskWithTensorBias: + bias = torch.arange(seq_len, dtype=dtype) + # NOTE(zhuohan): HF uses + # `bias = bias[None, :].repeat(prompt_len, 1)` + # here. We find that both biases give the same results, but + # the bias below more accurately follows the original ALiBi + # paper. + bias = bias[None, :] - bias[:, None] + + # When using custom attention bias, xformers requires the bias to + # be sliced from a tensor whose length is a multiple of 8. + padded_len = (seq_len + 7) // 8 * 8 + num_heads = alibi_slopes.shape[0] + bias = torch.empty( + batch_size, + num_heads, + seq_len, + padded_len, + device=alibi_slopes.device, + dtype=dtype, + )[:, :, :, :seq_len].copy_(bias) + bias.mul_(alibi_slopes[:, None, None]) + if num_heads != num_kv_heads: + bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads)) + attn_bias = LowerTriangularMaskWithTensorBias(bias) + return attn_bias + + +def _paged_attention( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + input_metadata: InputMetadata, + head_mapping: torch.Tensor, # num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + use_sqrt_alibi: bool = False +) -> torch.Tensor: + output = torch.empty_like(query) + + use_v2 = enable_infer_paged_attn is None and key_cache.dim() == 4 + if not use_v2: + block_size = value_cache.shape[3] + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + head_mapping, # num_kv_heads + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + else: + # Run PagedAttention V2. 
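        # Illustrative sizing (example numbers): with max_context_len = 1000 and
        # _PARTITION_SIZE = 256, max_num_partitions = ceil(1000 / 256) = 4, so the
        # V2 kernel writes one partial result per (seq, head, partition) into
        # tmp_output and reduces them using exp_sums and max_logits.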
+ block_size = value_cache.shape[2] + num_seqs, num_heads, head_size = query.shape + max_num_partitions = ( + (input_metadata.max_context_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + head_mapping, # num_kv_heads + scale, + input_metadata.block_tables, + input_metadata.context_lens, + block_size, + input_metadata.max_context_len, + alibi_slopes, + input_metadata.kv_cache_dtype, + ) + return output + + +# ↓ add for smoothquant +class DequantPagedAttention(PagedAttention): + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + quant_kv_cache: bool = False, + kv_quant_params: torch.Tensor = None, + quant_scale: float = 1.0, + use_per_token_quant: bool = True, + ) -> None: + super().__init__(num_heads, + head_size, + scale, + num_kv_heads, + alibi_slopes, + sliding_window) + self.register_parameter( + "quant_scale", + torch.nn.Parameter( + torch.tensor(quant_scale, dtype=torch.float32,requires_grad=False)) + ) + self.use_per_token_quant = use_per_token_quant + + def _apply(self, fn): + super()._apply(fn) + self.quant_scale.data = self.quant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.quant_scale.data = self.quant_scale.to(*args, **kwargs) + self.quant_scale.data = self.quant_scale.to(torch.float32) + return self + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_cache: Optional[torch.Tensor], + value_cache: Optional[torch.Tensor], + input_metadata: InputMetadata, + ) -> torch.Tensor: + out = super().forward( + query, + key, + value, + key_cache, + value_cache, + input_metadata, + ) + quant_out = torch.empty_like(out, dtype=torch.int8) + if self.use_per_token_quant: + scale = torch.empty(out.numel() // out.shape[-1], + dtype=torch.float32, + device=out.device) + ops.quant(quant_out, out, scale) + return quant_out, scale + else: + ops.quant(quant_out, out, self.quant_scale.item()) + return (quant_out, ) diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py new file mode 100644 index 0000000..6c4fd98 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -0,0 +1,5 @@ +from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe + +__all__ = [ + "fused_moe", +] \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py new file mode 100644 index 0000000..08e3c2d --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -0,0 +1,377 @@ +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl + +from vllm._C import ops +from vllm.logger import init_logger +from vllm.utils import is_hip + +logger = init_logger(__name__) + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + 
sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated, + and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to. + - expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A. + This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids` + by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. 
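    # (Illustrative tile sizes, e.g. BLOCK_SIZE_M = 64, BLOCK_SIZE_K = 32: each
    # program loads a 64 x 32 tile of A and a 32 x BLOCK_SIZE_N tile of B per
    # K-step, advancing both pointer blocks by BLOCK_SIZE_K strides each time.)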
+ # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens. 
+ - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations. + """ + sorted_ids = torch.empty( + (topk_ids.numel() + num_experts * (block_size - 1), ), + dtype=torch.int32, + device=topk_ids.device) + expert_ids = torch.empty((topk_ids.numel() + num_experts, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any]) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + 'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), ) + + fused_moe_kernel[grid]( + A, + B, + C, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2], + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16, + **config, + ) + + +@functools.lru_cache +def get_moe_configs(E: int, N: int) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of batch sizes + to configurations of the fused_moe kernel. To evaluate the kernel on a given batch + size bs, the closest batch size in the grid should be picked and the associated + configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs directory + device_name = torch.cuda.get_device_name().replace(" ", "_") + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", + f"E={E},N={N},device_name={device_name}.json") + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + f"Using configuration from {config_file_path} for MoE layer.") + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default configuration + return None + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism. 
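    Illustrative shapes (example values only): for a Mixtral-style layer with
    hidden size 4096, intermediate size 14336, 8 experts and topk = 2,
    hidden_states is (num_tokens, 4096), w1 is (8, 2 * 14336, 4096),
    w2 is (8, 4096, 14336) and gating_output is (num_tokens, 8).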
+ + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override for the kernel configuration. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + M, _ = hidden_states.shape + E, N, _ = w1.shape + + if is_hip(): + # The MoE kernels are not yet supported on ROCm. + routing_weights = torch.softmax(gating_output, + dim=-1, + dtype=torch.float32) + topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1) + else: + import vllm._moe_C as moe_kernels + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + ) + del token_expert_indicies # Not used. Will be used in the future. 
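    # Illustrative renormalization (example numbers): if a token's two selected
    # gate weights are [0.5, 0.3], dividing by their sum yields [0.625, 0.375],
    # so the expert outputs are mixed with weights that sum to 1.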
+ if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + if override_config: + config = override_config + else: + # First try to load optimal config from the file + configs = get_moe_configs(E, w2.shape[2]) + + if configs: + # If an optimal configuration map has been found, look up the optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + } + + if M <= E: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 1 + } + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( + topk_ids, config['BLOCK_SIZE_M'], E) + + invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, False, + topk_ids.shape[1], config) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + topk_weights, topk_ids, sorted_token_ids, + expert_ids, num_tokens_post_padded, True, 1, + config) + + if inplace: + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1, + out=hidden_states) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py new file mode 100644 index 0000000..6ee0a5b --- /dev/null +++ b/vllm/model_executor/layers/layernorm.py @@ -0,0 +1,216 @@ +"""Custom normalization layers.""" +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm._C import ops + + +class RMSNorm(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
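    Illustrative usage (tensor sizes are arbitrary):

        norm = RMSNorm(hidden_size=4096, eps=1e-6)
        y = norm(x)                            # plain RMSNorm
        y, residual = norm(x, residual=res)    # fused residual-add + RMSNorm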
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def _forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """PyTorch-native implementation equivalent to forward().""" + orig_dtype = x.dtype + x = x.to(torch.float32) + if residual is not None: + x = x + residual.to(torch.float32) + residual = x.to(orig_dtype) + + variance = x.pow(2).mean(dim=-1, keepdim=True) + x = x * torch.rsqrt(variance + self.variance_epsilon) + x = x.to(orig_dtype) * self.weight + if residual is None: + return x + else: + return x, residual + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + scale: float = 1.0, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + ops.fused_add_rms_norm( + x, + residual, + self.weight.data, + self.variance_epsilon, + scale, + ) + return x, residual + out = torch.empty_like(x) + ops.rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + return out + + +# ↓ add for smoothquant +class RMSNormQuant(nn.Module): + """Root mean square normalization. + + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + ) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward( + self, + x: torch.Tensor, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + out = torch.empty_like(x, dtype=torch.int8) + ops.rms_norm_quant( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + return out + + +class AddResidualRMSNormQuant(nn.Module): + """Root mean square normalization. + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. + Refer to https://arxiv.org/abs/1910.07467 + """ + + def __init__(self, + hidden_size: int, + eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, + x: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor = None) -> torch.Tensor: + out = torch.empty_like(x, dtype=torch.int8) + ops.fused_add_rms_norm_quant(out, x, residual, self.weight.data, self.variance_epsilon) + return out, residual + + +class DequantAddResidualRMSNormQuant(nn.Module): + """Root mean square normalization. + Computes x -> w * x / sqrt(E[x^2] + eps) where w is the learned weight. 
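    A rough sketch of the fused op's data flow (illustrative; the helper names
    and scale handling below are simplifications, not the real kernel API):

        x_fp = dequantize(x, scale if scale is not None else dequant_scale)
        residual = residual + x_fp
        out = quantize_int8(rms_norm(residual) * weight)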
+ Refer to https://arxiv.org/abs/1910.07467 + """ + + # TODO(Zhang Ying): use_per_token_dequant + def __init__(self, + hidden_size: int, + dequant_scale: float = 1.0, + use_per_token_dequant: bool = True, + eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + self.register_parameter( + "dequant_scale", + torch.nn.Parameter(torch.tensor(dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.use_per_token_dequant = use_per_token_dequant + + def _apply(self, fn): + super()._apply(fn) + self.dequant_scale.data = self.dequant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(torch.float32) + return self + + def forward(self, + x: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor = None) -> torch.Tensor: + out = torch.empty_like(x, dtype=torch.int8) + if self.use_per_token_dequant and scale is not None: + ops.dequant_fused_add_rms_norm_quant( + out, x, residual, self.weight.data,self.variance_epsilon, + scale, self.dequant_scale.item()) + else: + ops.dequant_fused_add_rms_norm_quant( + out, x, residual, self.weight.data, self.variance_epsilon, + None, self.dequant_scale.item()) + return out, residual + + +class DequantAddResidual(nn.Module): + def __init__(self, + dequant_scale: float = 1.0, + use_per_token_dequant: bool = True) -> None: + super().__init__() + self.register_parameter( + "dequant_scale", + torch.nn.Parameter(torch.tensor(dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.use_per_token_dequant = use_per_token_dequant + + def _apply(self, fn): + super()._apply(fn) + self.dequant_scale.data = self.dequant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(torch.float32) + return self + + def forward(self, + x: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor = None) -> torch.Tensor: + out = torch.empty_like(residual) + if self.use_per_token_dequant and scale is not None: + ops.dequant_add_residual(out, x, residual, scale, self.dequant_scale.item()) + else: + ops.dequant_add_residual(out, x, residual, None, self.dequant_scale.item()) + return out + + +class AddResidual(DequantAddResidual): + def __init__(self, + dequant_scale: float = 1.0, + use_per_token_dequant: bool = True): + super().__init__(dequant_scale,use_per_token_dequant) + + def forward(self, + x: torch.Tensor, + residual: torch.Tensor, + scale: torch.Tensor = None) -> torch.Tensor: + return x + residual diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000..915e04d --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,754 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import torch +import ixformer.functions as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) +from vllm.model_executor.parallel_utils.utils import ( + divide, split_tensor_along_last_dim) +from vllm.model_executor.utils import 
set_weight_attrs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +def adjust_marlin_shard(param, shard_size, shard_offset): + marlin_tile_size = getattr(param, "marlin_tile_size", None) + if marlin_tile_size is None: + return shard_size, shard_offset + + return shard_size * marlin_tile_size, shard_offset * marlin_tile_size + + +class LinearMethodBase(ABC): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + """Create weights for a linear layer.""" + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights to the input tensor.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization. + + Args: + separate_bias_add: If true, add bias separately after matrix + multiplication. + """ + + def __init__(self, separate_bias_add: bool = True): + self.separate_bias_add = separate_bias_add + + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + return {"weight": weight} + + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = weights["weight"] + if self.separate_bias_add: + if bias is not None: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) + + +class ReplicatedLinear(torch.nn.Module): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. 
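    Illustrative usage (example sizes): the layer returns an (output, output_bias)
    pair; output_bias is None unless skip_bias_add=True, in which case the bias is
    returned so the caller can fuse it with a later element-wise op:

        layer = ReplicatedLinear(1024, 4096, bias=True, skip_bias_add=True)
        y, bias = layer(x)    # y = x @ W^T, bias applied later by the caller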
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method = linear_method + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size, self.input_size, + self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + if bias: + self.bias = Parameter( + torch.empty(self.output_size, dtype=self.params_dtype)) + set_weight_attrs(self.bias, {"output_dim": 0}) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. 
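        # Illustrative partitioning (example values): with output_size = 4096 and
        # tp_size = 4, each rank owns output_size_per_partition = 1024 output
        # features; gather_output=True would all-gather them back to 4096.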
+ tp_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, tp_size) + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method = linear_method + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size_per_partition, self.input_size, + self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + if isinstance(weight, torch.Tensor): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. 
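    Illustrative example (names are examples, not from this file): a LLaMA-style
    MLP merges its gate and up projections into a single layer,

        MergedColumnParallelLinear(hidden_size, [intermediate_size, intermediate_size])

    so both sub-weights are loaded into one parameter at per-shard offsets.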
+ """ + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size, sum(output_sizes), bias, gather_output, + skip_bias_add, params_dtype, linear_method) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. 
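    For example (illustrative values): with total_num_heads = 32,
    total_num_kv_heads = 8, head_size = 128 and tp_size = 4, each rank holds
    8 query heads and 2 key/value heads, i.e. a (8 + 2 * 2) * 128 = 1536 wide
    slice of the (32 + 2 * 8) * 128 = 6144 wide concatenated QKV projection.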
+ When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + super().__init__(input_size, output_size, bias, False, skip_bias_add, + params_dtype, linear_method) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account for the tiling. 
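                    # (Illustrative: assuming marlin_tile_size = 16, a shard of
                    # size 256 at offset 512 becomes size 4096 at offset 8192 in
                    # the tiled layout; non-marlin params are returned unchanged.)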
+ shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + + # If marlin, we need to adjust the offset and size to account for the tiling. + shard_size, shard_offset = adjust_marlin_shard( + param, shard_size, shard_offset) + + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + if loaded_shard_id == "q": + shard_id = tp_rank + else: + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + ignore_warning = getattr(param, "ignore_warning", False) + if not ignore_warning: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + # Divide the weight matrix along the last dimension. 
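The row/column split sketched above can be checked numerically: shard A along its first (input) dimension, shard X along its last dimension, multiply per shard, and sum the partial products, which is what the all-reduce in forward() below performs across ranks. A small sketch with toy shapes and plain torch, no distributed setup assumed:

import torch

torch.manual_seed(0)
tp_size, in_features, out_features = 4, 16, 8
x = torch.randn(2, in_features)
a = torch.randn(in_features, out_features)

# Shard A along its first (input) dim and X along its last dim.
x_shards = x.chunk(tp_size, dim=-1)
a_shards = a.chunk(tp_size, dim=0)

# Each "rank" computes a partial product; summing emulates the all-reduce.
partial = [xi @ ai for xi, ai in zip(x_shards, a_shards)]
out = sum(partial)

assert torch.allclose(out, x @ a, atol=1e-5)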
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.tp_size)
+        self.skip_bias_add = skip_bias_add
+        if linear_method is None:
+            linear_method = UnquantizedLinearMethod()
+        self.linear_method = linear_method
+        self.linear_weights = self.linear_method.create_weights(
+            self.input_size_per_partition, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
+        for name, weight in self.linear_weights.items():
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+
+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When reduce_results is False, adding bias to the "
+                             "results can lead to incorrect results")
+
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.output_size, dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "output_dim": 0,
+                "weight_loader": self.weight_loader,
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        tp_rank = get_tensor_model_parallel_rank()
+        input_dim = getattr(param, "input_dim", None)
+        param_data = param.data
+        if input_dim is not None:
+            shard_size = param_data.shape[input_dim]
+            start_idx = tp_rank * shard_size
+            loaded_weight = loaded_weight.narrow(input_dim, start_idx,
+                                                 shard_size)
+        assert param_data.shape == loaded_weight.shape
+        param_data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        # Set up backprop all-reduce.
+        if self.input_is_parallel:
+            input_parallel = input_
+        else:
+            tp_rank = get_tensor_model_parallel_rank()
+            splitted_input = split_tensor_along_last_dim(
+                input_, num_partitions=self.tp_size)
+            input_parallel = splitted_input[tp_rank].contiguous()
+
+        # Matrix multiply.
+        output_parallel = self.linear_method.apply_weights(
+            self.linear_weights, input_parallel)
+        if self.reduce_results and self.tp_size > 1:
+            output_ = tensor_model_parallel_all_reduce(output_parallel)
+        else:
+            output_ = output_parallel
+
+        if not self.skip_bias_add:
+            output = output_ + self.bias if self.bias is not None else output_
+            output_bias = None
+        else:
+            output = output_
+            output_bias = self.bias
+        return output, output_bias
+
+
+# ↓ Added for SmoothQuant support.
+class QuantMergedColumnParallelLinear(MergedColumnParallelLinear):
+
+    def __init__(
+        self,
+        input_size: int,
+        output_sizes: List[int],
+        bias: bool = True,
+        gather_output: bool = False,
+        skip_bias_add: bool = False,
+        params_dtype: Optional[torch.dtype] = None,
+        linear_method: Optional[LinearMethodBase] = None,
+        dequant_scale: float = 1.0,
+    ):
+        super().__init__(input_size, output_sizes, bias, gather_output,
+                         skip_bias_add, params_dtype, linear_method)
+        self.register_parameter(
+            "dequant_scale",
+            torch.nn.Parameter(
+                torch.tensor(dequant_scale, dtype=torch.float32,
+                             requires_grad=False)))
+
+    def _apply(self, fn):
+        super()._apply(fn)
+        self.dequant_scale.data = self.dequant_scale.cpu()
+        return self
+
+    def to(self, *args, **kwargs):
+        super().to(*args, **kwargs)
+        self.dequant_scale.data = self.dequant_scale.to(*args, **kwargs)
+        self.dequant_scale.data = self.dequant_scale.to(torch.float32)
+        return self
+
+    def forward(self, input_):
+        bias = self.bias if not self.skip_bias_add else None
+
+        # Matrix multiply.
+        output_parallel = self.linear_method.apply_weights(
+            self.linear_weights, input_, bias, scale=None, dequant_scale=1.0)
+        if self.gather_output:
+            # All-gather across the partitions.
+ output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class QuantQKVParallelLinear(QKVParallelLinear): + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + q_dequant_scale: float = 1.0, + k_dequant_scale: float = 1.0, + v_dequant_scale: float = 1.0, + ): + super().__init__(hidden_size,head_size,total_num_heads,total_num_kv_heads, + bias,skip_bias_add,params_dtype,linear_method) + self.register_parameter( + "q_dequant_scale", + torch.nn.Parameter( + torch.tensor(q_dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.register_parameter( + "k_dequant_scale", + torch.nn.Parameter( + torch.tensor(k_dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + self.register_parameter( + "v_dequant_scale", + torch.nn.Parameter( + torch.tensor(v_dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + + def _apply(self, fn): + super()._apply(fn) + self.q_dequant_scale.data = self.q_dequant_scale.cpu() + self.k_dequant_scale.data = self.k_dequant_scale.cpu() + self.v_dequant_scale.data = self.v_dequant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.q_dequant_scale.data = self.q_dequant_scale.to(*args, **kwargs) + self.q_dequant_scale.data = self.q_dequant_scale.to(torch.float32) + self.k_dequant_scale.data = self.k_dequant_scale.to(*args, **kwargs) + self.k_dequant_scale.data = self.k_dequant_scale.to(torch.float32) + self.v_dequant_scale.data = self.v_dequant_scale.to(*args, **kwargs) + self.v_dequant_scale.data = self.v_dequant_scale.to(torch.float32) + return self + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_, bias, scale=None, dequant_scale=1.0) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class QuantRowParallelLinear(RowParallelLinear): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + linear_method: Optional[LinearMethodBase] = None, + dequant_scale: float = 1.0, + ): + super().__init__(input_size,output_size,bias,input_is_parallel, + skip_bias_add,params_dtype,reduce_results,linear_method) + self.register_parameter( + "dequant_scale", + torch.nn.Parameter( + torch.tensor(dequant_scale,dtype=torch.float32,requires_grad=False)) + ) + + def _apply(self, fn): + super()._apply(fn) + self.dequant_scale.data = self.dequant_scale.cpu() + return self + + def to(self, *args, **kwargs): + super().to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(*args, **kwargs) + self.dequant_scale.data = self.dequant_scale.to(torch.float32) + return self + + def forward(self, input_, scale=None): + # Set up backprop all-reduce. 
+ if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_parallel, self.bias, scale=scale, dequant_scale=self.dequant_scale.item(),is_row=True) + if self.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000..caf7283 --- /dev/null +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -0,0 +1,28 @@ +from typing import Type + +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.smoothquant import SmoothQuantConfig + +_QUANTIZATION_CONFIG_REGISTRY = { + "awq": AWQConfig, + "gptq": GPTQConfig, + "squeezellm": SqueezeLLMConfig, + "marlin": MarlinConfig, + "smoothquant": SmoothQuantConfig, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in _QUANTIZATION_CONFIG_REGISTRY: + raise ValueError(f"Invalid quantization method: {quantization}") + return _QUANTIZATION_CONFIG_REGISTRY[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", +] diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py new file mode 100644 index 0000000..97238fb --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq.py @@ -0,0 +1,170 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point})") + + def get_name(self) -> str: + return "awq" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half] + + def get_min_capability(self) -> int: + # The AWQ kernel only supports Turing or newer GPUs. 
+ return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + return cls(weight_bits, group_size, zero_point) + + def get_linear_method(self) -> "AWQLinearMethod": + return AWQLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] + + +class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + qweight = Parameter( + torch.empty( + input_size_per_partition, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + qzeros = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + scales = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": 0, + "output_dim": 1, + }) + return { + "qweight": qweight, + "qzeros": qzeros, + "scales": scales, + } + + def apply_weights(self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + scales = weights["scales"] + qzeros = weights["qzeros"] + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor) + # TODO align + """ + # num_tokens >= threshold + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0) + out = torch.matmul(reshaped_x, out) + else: + out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor) + """ + if bias is not None: + out = out + bias + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/base_config.py 
b/vllm/model_executor/layers/quantization/base_config.py new file mode 100644 index 0000000..6115e7c --- /dev/null +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -0,0 +1,64 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + @abstractmethod + def get_name(self) -> str: + """Name of the quantization method.""" + raise NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @abstractmethod + def get_min_capability(self) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @abstractmethod + def get_linear_method(self) -> LinearMethodBase: + """Get the linear method to use for the quantized linear layer.""" + raise NotImplementedError + + @abstractmethod + def get_scaled_act_names(self) -> List[str]: + """Returns the activation function names that should be post-scaled. + + For now, this is only used by AWQ. + """ + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py new file mode 100644 index 0000000..9d4f8da --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -0,0 +1,218 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional +from fractions import Fraction + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + + +class GPTQConfig(QuantizationConfig): + """Config class for GPTQ. 
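For illustration, a short sketch of how the helpers on QuantizationConfig above are used to parse a model's quantization dict; the dict values are made up, and the import assumes this patched tree is on the path.

from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

quant_dict = {"bits": 4, "group_size": 128, "desc_act": False}   # hypothetical config

# get_from_keys() returns the first matching key, so a config class can accept
# several spellings (e.g. "w_bit" or "bits"); from_config() implementations such
# as GPTQConfig.from_config below are built on top of it.
weight_bits = QuantizationConfig.get_from_keys(quant_dict, ["w_bit", "bits"])               # -> 4
group_size = QuantizationConfig.get_from_keys(quant_dict, ["q_group_size", "group_size"])   # -> 128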
+ + Reference: https://arxiv.org/abs/2210.17323 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.pack_factor = Fraction(32, self.weight_bits) + if self.weight_bits not in [2, 3, 4, 8]: + raise ValueError( + "Currently, only 2/3/4/8-bit weight quantization is supported for " + f"GPTQ, but got {self.weight_bits} bits.") + + def __repr__(self) -> str: + return (f"GPTQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + return cls(weight_bits, group_size, desc_act) + + def get_linear_method(self) -> "GPTQLinearMethod": + return GPTQLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class ExllamaState(Enum): + + UNUSED = enum.auto() + UNINITIALIZED = enum.auto() + READY = enum.auto() + + +class GPTQLinearMethod(LinearMethodBase): + """Linear method for GPTQ. + + Args: + quant_config: The GPTQ quantization config. + """ + + def __init__(self, quant_config: GPTQConfig): + self.quant_config = quant_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. + if input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size_per_partition % self.quant_config.pack_factor.numerator != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + exllama_state = ExllamaState.UNINITIALIZED + scale_and_zero_size = input_size // group_size + scale_and_zero_input_dim = None + if input_size != input_size_per_partition and self.quant_config.group_size != -1: + # For act-order models, we cannot use Exllama for row parallel layer + if self.quant_config.desc_act: + raise NotImplementedError() + exllama_state = ExllamaState.UNUSED + else: + # we need to partition qzeros and scales for exllama kernel + scale_and_zero_size = input_size_per_partition // group_size + scale_and_zero_input_dim = 0 + + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + g_idx = Parameter( + torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. + set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True}) + qzeros = Parameter( + torch.empty( + scale_and_zero_size, + output_size_per_partition // self.quant_config.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + "input_dim": scale_and_zero_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + scales = Parameter( + torch.empty( + scale_and_zero_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": scale_and_zero_input_dim, + "output_dim": 1, + }) + return { + "qweight": qweight, + "g_idx": g_idx, + "qzeros": qzeros, + "scales": scales, + "exllama_state": exllama_state, + } + + def apply_weights(self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + out_shape = x.shape[:-1] + (qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + # exllama needs to shuffle the weight after the weight is loaded + # here we do the shuffle on first forward pass + if weights["exllama_state"] == ExllamaState.UNINITIALIZED: + if self.quant_config.desc_act: + weights["g_idx"] = torch.argsort(weights["g_idx"]).to( + torch.int) + else: + weights["g_idx"] = None + # TODO align + """ + weights["g_idx"] = torch.empty((1, 1), device="meta") + """ + weights["exllama_state"] = ExllamaState.READY + ops.gptq_shuffle(weights["qweight"], weights["g_idx"], + self.quant_config.weight_bits) + output = ops.gptq_gemm(reshaped_x, weights["qweight"], + weights["qzeros"], weights["scales"], + weights["g_idx"], + weights["exllama_state"] == ExllamaState.READY, + self.quant_config.weight_bits) + if bias is not None: + output = output + bias + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py new file mode 100644 index 0000000..7566d78 --- /dev/null +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -0,0 +1,210 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from 
vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + + +class MarlinConfig(QuantizationConfig): + """Config class for Marlin. + + Reference: https://github.com/IST-DASLab/marlin/tree/master + """ + + def __init__( + self, + group_size: int, + ) -> None: + # Group size for the quantization. + self.group_size = group_size + if self.group_size != 128 and self.group_size != -1: + raise ValueError( + "Currently, only group size 128 and -1 (channelwise) is supported for " + f"Marlin, but got group_size of {self.group_size}") + + # 4 Bits packed into 32 bit datatype. + self.pack_factor = 32 // 4 + + # Tile size used by marlin kernels. + self.tile_size = 16 + + # Min out_features dim + self.min_n_threads = 64 + + # Min in_features dim + self.min_k_threads = 128 + + # Max parallel problems to solve at once (improves large batch performance) + self.max_parallel = 16 + + # Permutation length used by the marlin kernels. + self.perm_len = 1024 + + def __repr__(self) -> str: + return f"MarlinConfig(group_size={self.group_size}" + + @classmethod + def get_name(cls) -> str: + return "marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + # Need to figure it out + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": + group_size = cls.get_from_keys(config, ["group_size"]) + return cls(group_size) + + def get_linear_method(self) -> "MarlinLinearMethod": + return MarlinLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class MarlinLinearMethod(LinearMethodBase): + """Linear method for Marlin. + + Args: + quant_config: The Marlin quantization config. + """ + + def __init__(self, quant_config: MarlinConfig): + self.quant_config = quant_config + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + del output_size # Unused. + + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + if output_size_per_partition % self.quant_config.min_n_threads != 0: + raise ValueError( + f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by min_n_threads = {self.quant_config.min_n_threads}." + ) + if output_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + f"Weight output_size_per_partition = {output_size_per_partition} is not divisible by pack_factor = {self.quant_config.pack_factor}." + ) + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_k_threads != 0: + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition} is not divisible by min_k_threads = {self.quant_config.min_k_threads}." + ) + if self.quant_config.group_size != -1 and input_size_per_partition % self.quant_config.group_size != 0: + raise ValueError( + f"Weight input_size_per_partition = f{input_size_per_partition} is not divisible by group_size = {self.quant_config.group_size}." 
+ ) + + # Check that we have at least 4 tiles horizontally in the shard + num_tiles_per_perm = self.quant_config.perm_len // ( + self.quant_config.tile_size**2) + if output_size_per_partition % num_tiles_per_perm != 0: + raise ValueError( + "Each permutation group must reside on the same gpu") + + # Quantized 4Bit weights packed into Int32. + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.tile_size, + output_size_per_partition * self.quant_config.tile_size // + self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, + { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + "marlin_tile_size": self.quant_config.tile_size, + }, + ) + + # Determine if channelwise or not + input_groups = 1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size + + scales = Parameter( + torch.empty( + input_groups, + output_size_per_partition, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, + { + "input_dim": None if input_groups == 1 else 0, + "output_dim": 1, + }, + ) + + # Allocate workspace (Used for internal locking mechanism) + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_n_threads) * self.quant_config.max_parallel + workspace = Parameter(torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + requires_grad=False) + + return { + "B": qweight, + "s": scales, + "workspace": workspace, + } + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qweight = weights["B"] + scales = weights["s"] + workspace = weights["workspace"] + + x_2d = x.view(-1, x.shape[-1]) + + size_m = x_2d.shape[0] + size_k = x_2d.shape[1] + size_n = scales.shape[1] + + output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m, + size_n, size_k) + + output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], )) + + if bias is not None: + output.add_(bias) # In-place add + + return output diff --git a/vllm/model_executor/layers/quantization/smoothquant.py b/vllm/model_executor/layers/quantization/smoothquant.py new file mode 100644 index 0000000..1ea954c --- /dev/null +++ b/vllm/model_executor/layers/quantization/smoothquant.py @@ -0,0 +1,111 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size + + +class SmoothQuantConfig(QuantizationConfig): + """Config class for SmoothQuant + Reference: https://github.com/mit-han-lab/smoothquant + """ + + def __init__( + self, + weight_bits: int, + quant_type: str = "tensor" + ) -> None: + self.weight_bits = weight_bits + self.quant_type = quant_type + + if self.weight_bits != 8: + raise ValueError( + "Currently, only w8a8 quantization is supported for " + f"SmoothQuant, but got {self.weight_bits} bits.") + if self.quant_type != "tensor": + raise ValueError( + "Currently, only tensor wise quantization is supported for " + f"SmoothQuant, but got {self.quant_type} type quantization.") + + def __repr__(self) -> str: + return 
(f"SmoothQuantConfig(weight_bits={self.weight_bits}, " + f"quant_type={self.quant_type})") + + def get_name(self) -> str: + return "smoothquant" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half, torch.float] + + def get_min_capability(self) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + return [ + "quant_config.json", + "quantize_config.json", + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SmoothQuantConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + quant_type = cls.get_from_keys(config, ["quant_type", "q_type"]) + return cls(weight_bits, quant_type) + + def get_linear_method(self) -> "SmoothLinearMethod": + return SmoothLinearMethod(world_size=get_tensor_model_parallel_world_size()) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SmoothLinearMethod(LinearMethodBase): + def __init__(self, world_size, *args, **kwargs): + super().__init__(*args, **kwargs) + self.apply_dequant_after_row = world_size > 1 + self.dtpye = None + + def create_weights( + self, + input_size_per_partition: int, + output_size_per_partition: int, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + ) -> Dict[str, Any]: + weight = Parameter(torch.empty(output_size_per_partition, + input_size_per_partition, + dtype=torch.int8), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + self.dtpye = params_dtype + return {"weight": weight} + + def apply_weights( + self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor], + scale: Optional[torch.Tensor] = None, + dequant_scale: float = 1.0, + is_row: bool = False, + ) -> torch.Tensor: + x_shape = x.shape + x = x.view(-1, x_shape[-1]) + weight = weights["weight"] + y = torch.empty((x.shape[0], weight.shape[0]),dtype=torch.int32,device=x.device) + ops.linear_a8_w8_o32_(x, weight, y) + y = y.view(*x_shape[:-1], -1) + if is_row and self.apply_dequant_after_row: + # when tp > 1, duquant first(To improve accuracy?) + out = torch.empty_like(y, dtype=self.dtpye) + ops.dequant(out, y, scale, dequant_scale) + y = out + return y diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py new file mode 100644 index 0000000..9244e88 --- /dev/null +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -0,0 +1,129 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm._C import ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig +from vllm.utils import is_hip + + +class SqueezeLLMConfig(QuantizationConfig): + """Config class for SqueezeLLM. 
+ + Reference: https://arxiv.org/pdf/2306.07629 + """ + + def __init__( + self, + weight_bits: int, + ) -> None: + self.weight_bits = weight_bits + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"SqueezeLLM, but got {self.weight_bits} bits.") + + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" + + def get_name(self) -> str: + return "squeezellm" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half] + + def get_min_capability(self) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return ["quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + return cls(weight_bits) + + def get_linear_method(self) -> "SqueezeLLMLinearMethod": + return SqueezeLLMLinearMethod(self) + + def get_scaled_act_names(self) -> List[str]: + return [] + + +class SqueezeLLMLinearMethod(LinearMethodBase): + """Linear method for SqueezeLLM. + + Args: + quant_config: The SqueezeLLM quantization config. + """ + + def __init__(self, quant_config: SqueezeLLMConfig): + self.quant_config = quant_config + + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + if input_size_per_partition % self.quant_config.pack_factor != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + lookup_table = Parameter( + torch.empty( + output_size, + self.quant_config.weight_bits**2, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(lookup_table, { + "output_dim": 0, + }) + return { + "qweight": qweight, + "lookup_table": lookup_table, + } + + def apply_weights(self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + lookup_table = weights["lookup_table"] + out_shape = x.shape[:-1] + (qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + if is_hip(): + out_f = torch.zeros(out_shape, dtype=torch.float) + ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table) + out = out_f.to(dtype=torch.float16) + else: + # NOTE: The output tensor should be zero-initialized. 
+ out = torch.zeros(out_shape, dtype=torch.float16) + ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) + + if bias is not None: + out = out + bias + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py new file mode 100644 index 0000000..3e1cfc7 --- /dev/null +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -0,0 +1,392 @@ +from typing import Tuple, Optional +from functools import cached_property + +import torch +import torch.nn as nn +import torch.jit + + +class RejectionSampler(nn.Module): + """Apply modified rejection sampling as described in "Accelerating Large + Language Model Decoding with Speculative Sampling" + https://arxiv.org/pdf/2302.01318.pdf. + """ + + def __init__(self, strict_mode: bool = False): + """Create a rejection sampler. + + Args: + strict_mode: Whether or not to perform shape/device/dtype checks + during sampling. This catches correctness issues but adds + nontrivial latency. + """ + super().__init__() + self.probs_dtype = torch.float32 + self.token_id_dtype = torch.int64 + self._strict_mode = strict_mode + + # NOTE: A "bonus token" is accepted iff all proposal tokens are + # accepted. There is always only one possible bonus token. We store this + # value in a variable for readability. + self._num_bonus_tokens = 1 + + self.num_accepted_tokens: Optional[torch.Tensor] = None + self.num_emitted_tokens: Optional[torch.Tensor] = None + self.num_draft_tokens: int = 0 + + def init_gpu_tensors(self, rank: int) -> None: + assert self.num_accepted_tokens is None + device = f"cuda:{rank}" + self.num_accepted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + self.num_emitted_tokens = torch.tensor(0, + dtype=torch.long, + device=device) + + def forward( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> torch.Tensor: + """Sample token ids using rejection sampling. This accepts or rejects + tokens proposed by the draft model using the probability of each token + according to the draft and target models. + + In the worst case where all draft tokens are rejected, it is guaranteed + one correct token will be emitted. + + In the case where all draft tokens are accepted, a bonus token will be + accepted as its cheap to have the target model score this speculative + sequence. + + Args: + target_probs: The probability distribution over token ids given + context according to the target model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + bonus_token_ids: The "bonus" token ids that are accepted iff all + speculative tokens in a sequence are accepted. + shape = [batch_size, num_bonus_tokens] + + draft_probs: The probability distribution over token ids given + context according to the draft model. + shape = [batch_size, num_speculative_tokens, vocab_size] + + draft_token_ids: The token ids that were sampled from the draft + probabilities. + shape = [batch_size, num_speculative_tokens] + + Returns: + output_token_ids: The token ids sampled via rejection sampling, + or -1 if unable to sample a token because the previous token + was rejected. + shape = [batch_size, num_speculative_tokens + num_bonus_tokens] + """ + # Only perform shape/dtype/device checking in strict mode, as it adds + # overhead. 
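A hypothetical end-to-end call with toy shapes (batch_size=1, k=2 speculative tokens, vocab_size=4); it assumes this vLLM-style tree is importable and a CUDA device is available, since init_gpu_tensors allocates its counters on cuda:{rank}.

import torch
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

sampler = RejectionSampler()
sampler.init_gpu_tensors(rank=0)

target_probs = torch.softmax(torch.randn(1, 2, 4, device="cuda"), dim=-1)
draft_probs = torch.softmax(torch.randn(1, 2, 4, device="cuda"), dim=-1)
draft_token_ids = torch.randint(0, 4, (1, 2), device="cuda")
bonus_token_ids = torch.randint(0, 4, (1, 1), device="cuda")

out = sampler(target_probs, bonus_token_ids, draft_probs, draft_token_ids)
print(out)   # shape [1, 3]; -1 marks positions after the first rejection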
+ if self._strict_mode: + self._raise_if_incorrect_shape(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_incorrect_dtype(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_inconsistent_device(target_probs, bonus_token_ids, + draft_probs, draft_token_ids) + self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + bonus_token_ids, + draft_token_ids) + + accepted, recovered_token_ids = self._batch_modified_rejection_sampling( + target_probs, + draft_probs, + draft_token_ids, + ) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) + return output_token_ids + + def _batch_modified_rejection_sampling( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Perform modified rejection sampling on each sequence. + + Returns: + A tuple of two tensors: + 0: A bool tensor of which tokens in each sequence is accepted. + shape = [batch_size, k] + 1: Token ids sampled from a recovered distribution, to be used + when a token is rejected. + shape = [batch_size, k] + """ + + batch_size, k, vocab_size = draft_probs.shape + + # shape [batch_size, k] + accepted = self._get_accepted(target_probs, draft_probs, + draft_token_ids) + + recovered_probs = self._get_recovered_probs( + target_probs, draft_probs).reshape(batch_size * k, vocab_size) + + recovered_token_ids = _multinomial(recovered_probs, + num_samples=1).reshape( + batch_size, k) + return accepted, recovered_token_ids + + def _get_accepted( + self, + target_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_probs: torch.Tensor, # [batch_size, k, vocab_size] + draft_token_ids: torch.Tensor, # [batch_size, k] + ) -> torch.Tensor: + r"""Create bool matrix over the proposed draft tokens. If + True, then a token can be accepted, else it should be + rejected. + + Given :math:`q(\hat{x}_{n+1}|x_1, \dots, x_n)`, the probability of + :math:`\hat{x}_{n+1}` given context :math:`x_1, \dots, x_n` according + to the target model, and :math:`p(\hat{x}_{n+1}|x_1, \dots, x_n)`, the + same conditional probability according to the draft model, the token + is accepted with probability: + + .. math:: + \min\left(1, \frac{q(\hat{x}_{n+1}|x_1, \dots, x_n)} + {p(\hat{x}_{n+1}|x_1, \dots, x_n)}\right) + + This implementation does not apply causality. When using the output, + if a token is rejected, subsequent tokens should not be used. + + Returns a bool tensor of shape [batch_size, k] specifying which tokens + are accepted. 
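A toy check of the acceptance rule above for three independent proposals: each draft token is kept with probability min(1, q/p), where q and p are the target and draft probabilities of that token (the values below are made up).

import torch

torch.manual_seed(0)
draft_prob = torch.tensor([0.40, 0.10, 0.05])    # p(token | context) under the draft model
target_prob = torch.tensor([0.30, 0.20, 0.01])   # q(token | context) under the target model

capped_ratio = torch.minimum(target_prob / draft_prob, torch.ones(()))
accepted = torch.rand(3) < capped_ratio
print(capped_ratio)   # tensor([0.7500, 1.0000, 0.2000])
print(accepted)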
+ """ + batch_size, k, _ = draft_probs.shape + batch_indices = torch.arange(batch_size, + device=target_probs.device)[:, None] + probs_indicies = torch.arange(k, device=target_probs.device) + + # shape [batch_size, k] + selected_draft_probs = draft_probs[batch_indices, probs_indicies, + draft_token_ids] + + # shape [batch_size, k] + selected_target_probs = target_probs[batch_indices, probs_indicies, + draft_token_ids] + + uniform_rand = torch.rand(batch_size, + k, + dtype=self.probs_dtype, + device=target_probs.device) + capped_ratio = torch.minimum( + selected_target_probs / selected_draft_probs, + torch.full((1, ), 1, device=target_probs.device)) + accepted = uniform_rand < capped_ratio + + return accepted + + def _get_recovered_probs( + self, + target_probs: torch.Tensor, # [k, vocab_size] + draft_probs: torch.Tensor, # [k, vocab_size] + ) -> torch.Tensor: + r"""Create a probability distribution for each proposed token which can + be sampled if the proposed token is rejected. + + When this routine is applied sequentially, the true distribution of the + target model is recovered (within hardware numerics). + + The probability distribution used in this rejection case is constructed + as follows. Given :math:`q(x|x_1, \dots, x_n)`, the probability of + :math:`x` given context :math:`x_1, \dots, x_n` according to the target + model and :math:`p(x|x_1, \dots, x_n)`, the same conditional probability + according to the draft model: + + .. math:: + x_{n+1} \sim (q(x|x_1, \dots, x_n) - p(x|x_1, \dots, x_n))_+ + + where :math:`(f(x))_+` is defined as: + + .. math:: + (f(x))_+ = \frac{\max(0, f(x))}{\sum_x \max(0, f(x))} + + See https://github.com/vllm-project/vllm/pull/2336 for a visualization + of the draft, target, and recovered probability distributions. + + Returns a tensor of shape [batch_size, k, vocab_size]. + + Note: This batches operations on GPU and thus constructs the recovered + distribution for all tokens, even if they are accepted. This causes + division-by-zero errors, so we use self._smallest_positive_value to + avoid that. This introduces some drift to the distribution. + """ + _, k, _ = draft_probs.shape + + # shape [batch_size, k, vocab_size] + difference = target_probs - draft_probs + + # TODO(cade): Can we use logprobs instead of probs, and avoid the + # division-by-zero errors without introducing distribution drift? + + # shape [batch_size, k, vocab_size] + f = torch.clamp(difference, min=self._smallest_positive_value) + + # shape [batch_size, k, vocab_size] + recovered_probs = f / torch.sum(f, dim=-1).reshape(-1, k, 1) + + return recovered_probs + + @cached_property + def _smallest_positive_value(self) -> float: + """Return the smallest positive value representable by the probs dtype. + This value is used when constructing a distribution from which to sample + recovered tokens in the first rejection case. + + See _get_recovered_probs for more details + + Note that this isn't actually the smallest positive value representable + by float32, but the smallest positive normal value. + See https://en.wikipedia.org/wiki/Subnormal_number for more information. + """ + return torch.finfo(self.probs_dtype).tiny + + def _create_output( + self, + accepted: torch.Tensor, # [batch_size, k] + recovered_token_ids: torch.Tensor, # [batch_size, k] + draft_token_ids: torch.Tensor, # [batch_size, k] + bonus_token_ids: torch.Tensor, # [batch_size] + ) -> torch.Tensor: + """Format output. Returns a matrix of token ids. 
When + a token is rejected via rejection sampling, all subsequent + token ids are set to -1 for the sequence. + + shape = [batch_size, k + num_bonus_tokens] + """ + bonus_token_ids = bonus_token_ids.squeeze() + batch_size, k = recovered_token_ids.shape + + # Determine the index of the first False value for each row. + limits = (accepted == 0).max(1).indices + limits[~(accepted == 0).any(1)] = k + + # Create masks using the indices. + indices = torch.arange(k, device=accepted.device).unsqueeze(0) + accepted_mask = indices < limits.unsqueeze(1) + after_false_mask = indices == limits.unsqueeze(1) + + # Create an extended output tensor + output_with_bonus_tokens = -torch.ones( + (batch_size, k + self._num_bonus_tokens), + dtype=self.token_id_dtype, + device=accepted.device) + output = output_with_bonus_tokens[:, :k] + + # Fill in the first k columns of the output tensor using masks and data + # tensors. + output[:, :k] = torch.where(accepted_mask, draft_token_ids, + -torch.ones_like(draft_token_ids)) + + # Fill the last column. + # We check output directly as accepted may have True values inconsistent + # with causal acceptance. + output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1, + bonus_token_ids, -1) + + # Fill the recovered token ids. + output.mul_(~after_false_mask).add_( + recovered_token_ids.mul(after_false_mask)) + + self.num_accepted_tokens += accepted.sum() + self.num_emitted_tokens += (output_with_bonus_tokens != -1).sum() + self.num_draft_tokens += batch_size * k + + return output_with_bonus_tokens + + def _raise_if_incorrect_shape( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + (target_batch_size, num_target_probs, + target_vocab_size) = target_probs.shape + bonus_batch_size, num_bonus_tokens = bonus_token_ids.shape + draft_batch_size, num_draft_probs, draft_vocab_size = draft_probs.shape + draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape + + assert draft_batch_size == target_batch_size + assert num_draft_probs == num_target_probs + assert (draft_vocab_size == target_vocab_size + ), f"{draft_vocab_size=} {target_vocab_size=}" + + assert draft_token_ids_batch_size == draft_batch_size + assert num_draft_token_ids == num_draft_probs + + assert bonus_batch_size == target_batch_size + assert num_bonus_tokens == self._num_bonus_tokens + + def _raise_if_incorrect_dtype( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert all(probs.dtype == self.probs_dtype + for probs in [target_probs, draft_probs]) + assert all(token_ids.dtype == self.token_id_dtype + for token_ids in [bonus_token_ids, draft_token_ids]) + + def _raise_if_inconsistent_device( + self, + target_probs: torch.Tensor, + bonus_token_ids: torch.Tensor, + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + devices = [ + t.device for t in + [target_probs, bonus_token_ids, draft_probs, draft_token_ids] + ] + assert all([devices[0] == device for device in devices]) + + def _raise_if_out_of_bounds_vocab( + self, + vocab_size: int, + bonus_token_ids: torch.Tensor, + draft_token_ids: torch.Tensor, + ) -> None: + assert torch.all(bonus_token_ids < vocab_size) + assert torch.all(bonus_token_ids >= 0) + assert torch.all(draft_token_ids < vocab_size) + assert torch.all(draft_token_ids >= 0) + + +# torch.multinomial forces a GPU<->CPU sync. 
+# Therefore, we use an optimized implementation instead that skips the sync. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +@torch.jit.script +def _multinomial( + probs: torch.Tensor, + num_samples: int, +) -> torch.Tensor: + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs).exponential_(1.0) + return probs.div_(q).argmax(dim=1).view(-1, num_samples) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py new file mode 100644 index 0000000..933b24e --- /dev/null +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -0,0 +1,562 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings.""" +import math +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from vllm._C import ops + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + + cache = self._compute_cos_sin_cache() + cache = cache.to(torch.get_default_dtype()) + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): The HF implementation uses `torch.arange(...).float()`. + # However, we use `torch.arange(..., dtype=torch.float)` instead to + # avoid numerical issues with large base values (e.g., 10000000). + # This may cause a slight numerical difference between the HF + # implementation and ours. 
+ # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def _forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + query = query.flatten(-2) + key = key.flatten(-2) + return query, key + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # ops.rotary_embedding() is an in-place operation that + # updates the query and key tensors. + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. 
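A standalone version of what the method below computes, with toy sizes (rotary_dim=8, base=10000, a 4x scaling factor, assumed values): positions are divided by the scaling factor so that a longer context is mapped back onto the rotation angles the model was trained with.

import torch

rotary_dim, base, max_position_embeddings, scaling_factor = 8, 10000, 16, 4.0

inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim))
t = torch.arange(max_position_embeddings * scaling_factor, dtype=torch.float) / scaling_factor
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1)
print(cache.shape)   # torch.Size([64, 8]): 4x more positions, same frequency range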
+ max_len = self.max_position_embeddings * self.scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / self.scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range(low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> int: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: float = 32, + beta_slow: float = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, +) -> RotaryEmbedding: + key = (head_size, rotary_dim, max_position, base, is_neox_style, + tuple(rope_scaling.items()) if rope_scaling is not None else None) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style) + else: + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + scaling_factor = rope_scaling["factor"] + if scaling_type == "llama3": + dtype = torch.get_default_dtype() + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "linear": + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor) + elif scaling_type == "dynamic": + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor) + elif scaling_type == "yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = 
YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, + **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb + + +# ↓ add for smoothquant +class DequantRotaryEmbedding(RotaryEmbedding): + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + q_dequant_scale: float, + k_dequant_scale: float, + v_dequant_scale: float + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # pos_encoding_ops.rotary_embedding() is an in-place operation that + # updates the query and key tensors. + query_dequant = torch.empty_like(query, dtype=self.cos_sin_cache.dtype) + key_dequant = torch.empty_like(key, dtype=self.cos_sin_cache.dtype) + value_dequant = torch.empty_like(value, dtype=self.cos_sin_cache.dtype) + + ops.dequant(value_dequant, value, None, v_dequant_scale) + ops.dequant_rotary_embedding( + positions, + query, + key, + self.head_size, + self.cos_sin_cache, + query_dequant, + key_dequant, + q_dequant_scale, + k_dequant_scale, + self.is_neox_style, + ) + return query_dequant, key_dequant, value_dequant + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + +class DequantLinearScalingRotaryEmbedding(LinearScalingRotaryEmbedding, + DequantRotaryEmbedding): + + def __init__(self, *args, **kwargs): + LinearScalingRotaryEmbedding.__init__(self, *args, **kwargs) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dequant_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return DequantRotaryEmbedding.forward(self, positions, query, key, + value, dequant_scale) + +class DequantDynamicNTKScalingRotaryEmbedding(DynamicNTKScalingRotaryEmbedding, + DequantRotaryEmbedding): + + def __init__(self, *args, **kwargs): + DynamicNTKScalingRotaryEmbedding.__init__(self, *args, **kwargs) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dequant_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return DequantRotaryEmbedding.forward(self, positions, query, key, + value, dequant_scale) + +class 
DequantYaRNScalingRotaryEmbedding(YaRNScalingRotaryEmbedding, + DequantRotaryEmbedding): + + def __init__(self, *args, **kwargs): + YaRNScalingRotaryEmbedding.__init__(self, *args, **kwargs) + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dequant_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return DequantRotaryEmbedding.forward(self, positions, query, key, + value, dequant_scale) + +_DEQUANT_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_dequant_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, +) -> RotaryEmbedding: + key = (head_size, rotary_dim, max_position, base, is_neox_style, + tuple(rope_scaling.items()) if rope_scaling is not None else None) + if key in _DEQUANT_ROPE_DICT: + return _DEQUANT_ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = DequantRotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style) + else: + scaling_type = rope_scaling["type"] + scaling_factor = rope_scaling["factor"] + if scaling_type == "linear": + rotary_emb = DequantLinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor) + elif scaling_type == "dynamic": + rotary_emb = DequantDynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor) + elif scaling_type == "yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = DequantYaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, + **extra_kwargs) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _DEQUANT_ROPE_DICT[key] = rotary_emb + return rotary_emb \ No newline at end of file diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py new file mode 100644 index 0000000..48db579 --- /dev/null +++ b/vllm/model_executor/layers/sampler.py @@ -0,0 +1,598 @@ +"""A layer that samples the next tokens from the model's outputs.""" +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_gather,tensor_model_parallel_all_gather) +from vllm.model_executor.sampling_metadata import SamplingMetadata, SamplingTensors +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import (PromptLogprobs, SampleLogprobs, SamplerOutput, + SequenceData, SequenceGroupOutput, SequenceOutput) +from vllm.utils import is_neuron +import ixformer.functions as ixf_F + + +class Sampler(nn.Module): + """Samples the next tokens from the model's outputs. + + This layer does the following: + 1. Discard the hidden states that are not used for sampling (i.e., all + tokens except the final one in each prompt). + 2. Compute the logits for the next tokens. + 3. Apply presence, frequency and repetition penalties. + 4. Apply temperature scaling. + 5. Apply top-p and top-k truncation. + 6. Sample the next tokens. + Here, each sequence group within the batch can have different sampling + parameters (e.g., sampling method, temperature, top-p, top-k, etc.). 
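+    The output is a list of SequenceGroupOutput objects, one per sequence
+    group, each holding the sampled tokens together with their logprobs
+    (and prompt logprobs, if requested).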
+ """ + + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: + super().__init__() + self.vocab_size = vocab_size + # Transformers-neuronx generate outputs as logits directly. + self.logits_as_hidden_states = is_neuron() + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor], + logits_scale = None) -> torch.Tensor: + # Get the logits for the next tokens. + if logits_scale is None: + logits = ixf_F.linear(hidden_states, embedding) + else: + logits = ixf_F.linear(hidden_states / logits_scale, embedding) + # TODO align + """ + logits = torch.matmul(hidden_states, embedding.t()) + """ + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # TODO align + """ + logits = tensor_model_parallel_gather(logits) + """ + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + + def forward( + self, + embedding: torch.Tensor, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + embedding_bias: Optional[torch.Tensor] = None, + logits_scale = None, + ) -> Optional[SamplerOutput]: + # Get the hidden states that we use for sampling. + if self.logits_as_hidden_states: + logits = hidden_states + else: + hidden_states = _prune_hidden_states(hidden_states, + sampling_metadata) + + # Get the logits for the next tokens. + logits = self._get_logits(hidden_states, embedding, embedding_bias, logits_scale) + + # Only perform sampling in the driver worker. + # Note: `_get_logits` is still distributed across TP workers because + # the `embedding` weight is distributed across TP workers. + # TODO(zhuohan): Change the get_logits part to a separate stage. + if not sampling_metadata.perform_sampling: + return None + + assert logits is not None + _, vocab_size = logits.shape + + # Apply logits processors (if any). + logits = _apply_logits_processors(logits, sampling_metadata) + + # Prepare sampling tensors with pinned memory to avoid blocking. + (sampling_tensors, do_penalties, do_top_p_top_k, + do_min_p) = SamplingTensors.from_sampling_metadata( + sampling_metadata, vocab_size, logits.device, logits.dtype) + + # Apply presence and frequency penalties. + if do_penalties: + logits = _apply_penalties(logits, sampling_tensors.prompt_tokens, + sampling_tensors.output_tokens, + sampling_tensors.presence_penalties, + sampling_tensors.frequency_penalties, + sampling_tensors.repetition_penalties) + + # Apply temperature scaling. + # Use in-place division to avoid creating a new tensor. + logits.div_(sampling_tensors.temperatures.unsqueeze_(dim=1)) + + if do_top_p_top_k: + logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps, + sampling_tensors.top_ks) + + if do_min_p: + logits = _apply_min_p(logits, sampling_tensors.min_ps) + + # We use float32 for probabilities and log probabilities. + # Compute the probabilities. + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # Compute the log probabilities. + # Use log_softmax to ensure numerical stability. + logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + + # Sample the next tokens. + sample_results = _sample(probs, logprobs, sampling_metadata) + # Get the logprobs query results. 
+ prompt_logprobs, sample_logprobs = _get_logprobs( + logprobs, sampling_metadata, sample_results) + return _build_sampler_output(sample_results, sampling_metadata, + prompt_logprobs, sample_logprobs) + + +def _prune_hidden_states( + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + return hidden_states.index_select(0, + sampling_metadata.selected_token_indices) + + +def _get_bin_counts_and_mask( + tokens: torch.Tensor, + vocab_size: int, + num_seqs: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + # Compute the bin counts for the tokens. + # vocab_size + 1 for padding. + bin_counts = torch.zeros((num_seqs, vocab_size + 1), + dtype=torch.long, + device=tokens.device) + bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens)) + bin_counts = bin_counts[:, :vocab_size] + mask = bin_counts > 0 + + return bin_counts, mask + + +def _apply_logits_processors( + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> torch.Tensor: + logits_row_idx = 0 + found_logits_processors = False + for seq_ids, sampling_params in sampling_metadata.seq_groups: + logits_processors = sampling_params.logits_processors + if logits_processors: + found_logits_processors = True + for seq_id in seq_ids: + logits_row = logits[logits_row_idx] + token_ids = sampling_metadata.seq_data[seq_id].output_token_ids + for logits_processor in logits_processors: + logits_row = logits_processor(token_ids, logits_row) + logits[logits_row_idx] = logits_row + logits_row_idx += 1 + else: + logits_row_idx += len(seq_ids) + if found_logits_processors: + assert logits_row_idx == logits.shape[0] + return logits + + +def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor, + output_tokens_tensor: torch.Tensor, + presence_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + repetition_penalties: torch.Tensor) -> torch.Tensor: + num_seqs, vocab_size = logits.shape + _, prompt_mask = _get_bin_counts_and_mask(prompt_tokens_tensor, vocab_size, + num_seqs) + output_bin_counts, output_mask = _get_bin_counts_and_mask( + output_tokens_tensor, vocab_size, num_seqs) + + repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size) + repetition_penalties[~(prompt_mask | output_mask)] = 1.0 + logits = torch.where(logits > 0, logits / repetition_penalties, + logits * repetition_penalties) + + # We follow the definition in OpenAI API. + # Refer to https://platform.openai.com/docs/api-reference/parameter-details + logits -= frequency_penalties.unsqueeze_(dim=1) * output_bin_counts + logits -= presence_penalties.unsqueeze_(dim=1) * output_mask + return logits + + +def _apply_top_k_top_p( + logits: torch.Tensor, + p: torch.Tensor, + k: torch.Tensor, +) -> torch.Tensor: + logits_sort, logits_idx = logits.sort(dim=-1, descending=False) + + # Apply top-k. + top_k_mask = logits_sort.size(1) - k.to(torch.long) + # Get all the top_k values. + top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) + top_k_mask = logits_sort < top_k_mask + logits_sort.masked_fill_(top_k_mask, -float("inf")) + + # Apply top-p. + probs_sort = logits_sort.softmax(dim=-1) + probs_sum = probs_sort.cumsum(dim=-1) + top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) + # at least one + top_p_mask[:, -1] = False + logits_sort.masked_fill_(top_p_mask, -float("inf")) + + # Re-sort the probabilities. 
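+    # The scatter below builds the inverse permutation of logits_idx, so the
+    # gather that follows returns the masked logits to their original vocab
+    # order (an argsort-free way of undoing the earlier sort).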
+ src = torch.arange(logits_idx.shape[-1], + device=logits_idx.device).expand_as(logits_idx) + logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1, + index=logits_idx, + src=src) + logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv) + return logits + + +def _apply_min_p( + logits: torch.Tensor, + min_p: torch.Tensor, +) -> torch.Tensor: + """ + Adapted from + https://github.com/oobabooga/text-generation-webui/blob/3146124ec01f02c8fb1650a6517cf1b60b537aaf/modules/sampler_hijack.py#L16C17-L16C17 + """ + probs = torch.softmax(logits, dim=-1) + top_probs, _ = probs.max(dim=-1, keepdim=True) + scaled_min_p = min_p.unsqueeze_(dim=1) * top_probs + tokens_to_remove = probs < scaled_min_p + logits = logits.masked_fill_(tokens_to_remove, -float("inf")) + + return logits + + +def _greedy_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + samples: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + samples = samples.tolist() + sample_idx = 0 + results = [] + for seq_group in selected_seq_groups: + seq_ids, _ = seq_group + num_parent_seqs = len(seq_ids) + assert num_parent_seqs == 1, ( + "Greedy sampling should have only one seq.") + parent_ids = list(range(num_parent_seqs)) + next_token_ids = [samples[sample_idx]] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +def _random_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + random_samples: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + # Find the maximum best_of value of the prompt phase requests. + random_samples = random_samples.cpu() + sample_idx = 0 + results = [] + for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): + seq_ids, sampling_params = seq_group + num_parent_seqs = len(seq_ids) + if is_prompt: + # Prompt phase. + parent_ids = [0] * sampling_params.best_of + next_token_ids = random_samples[ + sample_idx, :sampling_params.best_of].tolist() + else: + # Generation phase. + parent_ids = list(range(num_parent_seqs)) + next_token_ids = random_samples[sample_idx:sample_idx + + num_parent_seqs, 0].tolist() + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + return results + + +def _beam_search_sample( + selected_seq_groups: List[Tuple[List[int], SamplingParams]], + is_prompts: List[bool], + seq_data: Dict[int, SequenceData], + logprobs: torch.Tensor, +) -> List[Tuple[List[int], List[int]]]: + # We sample 2 * beam_width candidates to make sure that with high + # probability we can get `beam_width` candidates in addition to + # the finished sequences for the next iteration. See + # https://github.com/tensorflow/tensor2tensor/blob/bafdc1b67730430d38d6ab802cbd51f9d053ba2e/tensor2tensor/utils/beam_search.py#L557-L563 + # for details. See also HF reference: + # https://github.com/huggingface/transformers/blob/a4dd53d88e4852f023332d284ff07a01afcd5681/src/transformers/generation/utils.py#L3063-L3065 + # + # NOTE: Beam search is not vectorized, so its speed can be slower than + # other sampling methods. + sample_idx = 0 + results = [] + for seq_group, is_prompt in zip(selected_seq_groups, is_prompts): + seq_ids, sampling_params = seq_group + num_parent_seqs = len(seq_ids) + beam_width = sampling_params.best_of + seq_group_logprobs = logprobs[sample_idx:sample_idx + num_parent_seqs] + if is_prompt: + # Prompt phase. 
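+            # A prompt has exactly one parent sequence, so the 2 * beam_width
+            # candidates are simply the top tokens of that single distribution.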
+ assert num_parent_seqs == 1, ( + "Prompt input should have only one seq.") + parent_ids = [0] * (2 * beam_width) + _, next_token_ids = torch.topk(seq_group_logprobs[0], + 2 * beam_width) + next_token_ids = next_token_ids.tolist() + else: + # Generation phase. + cumulative_logprobs = [ + seq_data[seq_id].cumulative_logprob for seq_id in seq_ids + ] + cumulative_logprobs = torch.tensor( + cumulative_logprobs, + dtype=torch.float, + device=seq_group_logprobs.device) + seq_group_logprobs = (seq_group_logprobs + + cumulative_logprobs.unsqueeze(dim=1)) + _, topk_ids = torch.topk(seq_group_logprobs.flatten(), + 2 * beam_width) + topk_ids = topk_ids.tolist() + vocab_size = seq_group_logprobs.size(-1) + parent_ids = [i // vocab_size for i in topk_ids] + next_token_ids = [i % vocab_size for i in topk_ids] + results.append((next_token_ids, parent_ids)) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) + return results + + +# torch.multinomial forces a GPU<->CPU sync. +# Therefore, we use an optimized implementation instead. +# Note that we always sample with replacement. +# probs will be modified in place, but this is fine, as we pass +# in a copy already. +def _multinomial( + probs: torch.Tensor, + num_samples: int, + seq_groups: Optional[List[Tuple[List[int], SamplingParams]]] = None, + generators: Optional[List[torch.Generator]] = None, +) -> torch.Tensor: + if num_samples > 1: + # This is equivalent to torch.repeat_interleaved (which also + # forces a GPU<->CPU sync). + # This allows us to do sampling with replacement by creating + # num_samples copies of each row in the tensor, and then + # batch sampling the resulting tensor. + probs = probs[:, None, :].expand(probs.shape[0], num_samples, + probs.shape[1]).contiguous().view( + -1, probs.shape[1]) + q = torch.empty_like(probs) + if seq_groups is None: + q.exponential_() + else: + sample_idx = 0 + for (seq_ids, _), generator in zip(seq_groups, generators): + next_sample_idx = sample_idx + len(seq_ids) * num_samples + q[sample_idx:next_sample_idx].exponential_(generator=generator) + sample_idx = next_sample_idx + return probs.div_(q).argmax(dim=1).view(-1, num_samples) + + +def _sample( + probs: torch.Tensor, + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, +) -> List[Tuple[List[int], List[int]]]: + categorized_seq_group_ids = {t: [] for t in SamplingType} + categorized_sample_indices = sampling_metadata.categorized_sample_indices + for i, seq_group in enumerate(sampling_metadata.seq_groups): + _, sampling_params = seq_group + sampling_type = sampling_params.sampling_type + categorized_seq_group_ids[sampling_type].append(i) + + sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} + sample_metadata = {} + multinomial_samples = {} + + # Counterintiutively, having two loops here is actually faster. + # The first loop can run without waiting on GPU<->CPU sync. 
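+    # Descriptive note: the work launched in the first loop (argmax for
+    # greedy, the exponential-noise multinomial trick, logprob slicing for
+    # beam search) stays on the GPU; the .tolist()/.cpu() calls that force a
+    # device sync only happen inside the per-type helpers called from the
+    # second loop.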
+ for sampling_type in SamplingType: + sample_indices = categorized_sample_indices[sampling_type] + num_tokens = len(sample_indices) + if num_tokens == 0: + continue + seq_group_ids = categorized_seq_group_ids[sampling_type] + seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids] + is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids] + sample_metadata[sampling_type] = (seq_group_ids, seq_groups, + is_prompts, sample_indices) + if sampling_type == SamplingType.GREEDY: + greedy_samples = torch.argmax(logprobs[sample_indices.long()], + dim=-1) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + max_best_of = 1 + for seq_group, is_prompt in zip(seq_groups, is_prompts): + if is_prompt: + _, sampling_params = seq_group + max_best_of = max(max_best_of, sampling_params.best_of) + seeded_args = {} if sampling_type == SamplingType.RANDOM else { + "seq_groups": seq_groups, + "generators": sampling_metadata.generators, + } + multinomial_samples[sampling_type] = _multinomial( + probs[sample_indices.long()], max_best_of, **seeded_args) + elif sampling_type == SamplingType.BEAM: + beam_search_logprobs = logprobs[sample_indices] + else: + raise ValueError(f"Unsupported sampling type: {sampling_type}") + + # GPU<->CPU sync happens in the loop below. + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + seq_group_ids, seq_groups, is_prompts, sample_indices = sample_metadata[ + sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, is_prompts, + multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, is_prompts, + sampling_metadata.seq_data, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_ids, sample_results)) + + sample_results = [ + sample_results_dict[i] + for i in range(len(sampling_metadata.seq_groups)) + ] + return sample_results + + +def _get_logprobs( + logprobs: torch.Tensor, + sampling_metadata: SamplingMetadata, + sample_results: List[Tuple[List[int], List[int]]], +) -> Tuple[List[Optional[List[Optional[Dict[int, float]]]]], List[List[Dict[ + int, float]]]]: + # Prepare query indices + batched_logprobs_query_seq_indices: List[int] = [] + batched_logprobs_query_token_indices: List[int] = [] + largest_num_logprobs = 0 + sample_idx = 0 + for i, (seq_group, sample_result) in enumerate( + zip(sampling_metadata.seq_groups, sample_results)): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result + num_parent_seqs = len(seq_ids) + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + largest_num_logprobs = max(largest_num_logprobs, + sampling_params.prompt_logprobs) + prompt_len = sampling_metadata.prompt_lens[i] + prompt_tokens = sampling_metadata.seq_data[ + seq_ids[0]].prompt_token_ids + batched_logprobs_query_seq_indices.extend( + sample_idx + j for j in range(prompt_len - 1)) + batched_logprobs_query_token_indices.extend( + token_id for token_id in prompt_tokens[1:]) + sample_idx += prompt_len - 1 + batched_logprobs_query_seq_indices.extend( + [sample_idx + parent_id for parent_id in parent_ids]) + batched_logprobs_query_token_indices.extend(next_token_ids) + if sampling_params.logprobs is not None: + largest_num_logprobs = max(largest_num_logprobs, + 
sampling_params.logprobs) + sample_idx += num_parent_seqs + assert sample_idx == logprobs.size(0) + + # Batched query for logprobs of selected token + batched_logprobs_query_result = logprobs[[ + batched_logprobs_query_seq_indices, + batched_logprobs_query_token_indices + ]] + + # Batched query for logprobs of topk tokens + if largest_num_logprobs > 0: + top_logprobs, top_token_ids = torch.topk(logprobs, + largest_num_logprobs, + dim=-1) + top_logprobs = top_logprobs.cpu() + top_token_ids = top_token_ids.cpu() + else: + top_logprobs, top_token_ids = None, None + + batched_logprobs_query_result = batched_logprobs_query_result.cpu() + + # Gather results + result_prompt_logprobs: List[Optional[PromptLogprobs]] = [] + result_sample_logprobs: List[SampleLogprobs] = [] + sample_idx = 0 + query_result_idx = 0 + for i, (seq_group, sample_result) in enumerate( + zip(sampling_metadata.seq_groups, sample_results)): + seq_ids, sampling_params = seq_group + next_token_ids, parent_ids = sample_result + + # Prompt logprobs + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + num_logprobs = sampling_params.prompt_logprobs + prompt_len = sampling_metadata.prompt_lens[i] + prompt_tokens = sampling_metadata.seq_data[ + seq_ids[0]].prompt_token_ids + group_prompt_logprobs: PromptLogprobs = [None] + for token_id in prompt_tokens[1:]: + prompt_logprobs_dict = { + token_id: + batched_logprobs_query_result[query_result_idx].item() + } + if num_logprobs > 0: + prompt_logprobs_dict.update( + zip(top_token_ids[sample_idx, :num_logprobs].tolist(), + top_logprobs[sample_idx, :num_logprobs].tolist())) + group_prompt_logprobs.append(prompt_logprobs_dict) + sample_idx += 1 + query_result_idx += 1 + result_prompt_logprobs.append(group_prompt_logprobs) + else: + result_prompt_logprobs.append(None) + + # Sample logprobs + num_logprobs = sampling_params.logprobs + if num_logprobs is None: + num_logprobs = 0 + group_sample_logprobs: SampleLogprobs = [] + for next_token_id, parent_id in zip(next_token_ids, parent_ids): + sample_logprobs_dict = { + next_token_id: + batched_logprobs_query_result[query_result_idx].item() + } + query_result_idx += 1 + if num_logprobs > 0: + sample_logprobs_dict.update( + zip( + top_token_ids[sample_idx + + parent_id, :num_logprobs].tolist(), + top_logprobs[sample_idx + + parent_id, :num_logprobs].tolist())) + group_sample_logprobs.append(sample_logprobs_dict) + result_sample_logprobs.append(group_sample_logprobs) + sample_idx += len(seq_ids) + + return result_prompt_logprobs, result_sample_logprobs + + +def _build_sampler_output( + sample_results: List[Tuple[List[int], List[int]]], + sampling_metadata: SamplingMetadata, + prompt_logprobs: List[Optional[PromptLogprobs]], + sample_logprobs: List[SampleLogprobs], +) -> SamplerOutput: + sampler_output = [] + for (seq_group, sample_result, group_prompt_logprobs, + group_sample_logprobs) in zip(sampling_metadata.seq_groups, + sample_results, prompt_logprobs, + sample_logprobs): + seq_ids, _ = seq_group + next_token_ids, parent_ids = sample_result + seq_outputs = [] + for parent_id, next_token_id, logprobs in zip(parent_ids, + next_token_ids, + group_sample_logprobs): + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, logprobs)) + sampler_output.append( + SequenceGroupOutput(seq_outputs, group_prompt_logprobs)) + return sampler_output diff --git a/vllm/model_executor/layers/triton_kernel/__init__.py b/vllm/model_executor/layers/triton_kernel/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/vllm/model_executor/layers/triton_kernel/prefix_prefill.py b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py new file mode 100644 index 0000000..70f0922 --- /dev/null +++ b/vllm/model_executor/layers/triton_kernel/prefix_prefill.py @@ -0,0 +1,745 @@ +# The kernels in this file are adapted from LightLLM's context_attention_fwd: +# https://github.com/ModelTC/lightllm/blob/main/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py + +import torch +import triton +import triton.language as tl + +if triton.__version__ >= "2.1.0": + + @triton.jit + def _fwd_kernel( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # 
scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_flash_attn_v2( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + 
other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + # acc /= l_i[:, None] + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o 
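+        # NOTE: unlike the first kernel, the accumulator here is stored
+        # without the final division by l_i (that normalization line above is
+        # commented out), so the stored values are unnormalized; any softmax
+        # normalization would presumably be handled by the caller if this
+        # variant is used.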
+ tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @triton.jit + def _fwd_kernel_alibi( + Q, + K, + V, + K_cache, + V_cache, + B_Loc, + sm_scale, + B_Start_Loc, + B_Seqlen, + B_Ctxlen, + Alibi_slopes, + block_size, + x, + Out, + stride_b_loc_b, + stride_b_loc_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_k_cache_bs, + stride_k_cache_h, + stride_k_cache_d, + stride_k_cache_bl, + stride_k_cache_x, + stride_v_cache_bs, + stride_v_cache_h, + stride_v_cache_d, + stride_v_cache_bl, + num_queries_per_kv: int, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + # attn_bias[] + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // num_queries_per_kv + + # cur_batch_seq_len: the length of prompts + # cur_batch_ctx_len: the length of prefix + # cur_batch_in_all_start_index: the start id of the dim=0 + cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + offs_d[None, :] * stride_qd) + + q = tl.load( + Q + off_q, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + # # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = 0 + for start_n in range(0, cur_batch_ctx_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + bn = tl.load(B_Loc + cur_batch * stride_b_loc_b + + ((start_n + offs_n) // block_size) * stride_b_loc_s, + mask=(start_n + offs_n) < cur_batch_ctx_len, + other=0) + off_k = (bn[None, :] * stride_k_cache_bs + + cur_kv_head * stride_k_cache_h + + (offs_d[:, None] // x) * stride_k_cache_d + + ((start_n + offs_n[None, :]) % block_size) * + stride_k_cache_bl + + (offs_d[:, None] % x) * stride_k_cache_x) + off_v = ( + bn[:, None] * stride_v_cache_bs + + cur_kv_head * stride_v_cache_h + + offs_d[None, :] * stride_v_cache_d + + (start_n + offs_n[:, None]) % block_size * stride_v_cache_bl) + k = tl.load(K_cache + off_k, + mask=(start_n + offs_n[None, :]) < cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, + float("-inf")) + qk *= sm_scale + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update 
output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(V_cache + off_v, + mask=(start_n + offs_n[:, None]) < cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + off_k = (offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + + offs_d[:, None] * stride_kd) + off_v = (offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + + offs_d[None, :] * stride_vd) + k_ptrs = K + off_k + v_ptrs = V + off_v + + block_mask = tl.where( + block_start_loc < cur_batch_seq_len - cur_batch_ctx_len, 1, 0) + + # init alibi + alibi_slope = tl.load(Alibi_slopes + cur_head) + alibi_start_q = tl.arange( + 0, BLOCK_M) + block_start_loc + cur_batch_ctx_len + alibi_start_k = cur_batch_ctx_len + # # init debugger + # offset_db_q = tl.arange(0, BLOCK_M) + block_start_loc + # offset_db_k = tl.arange(0, BLOCK_N) + # calc q[BLOCK_M, BLOCK_MODEL] mul k[prefix_len: , BLOCK_DMODEL] + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, allow_tf32=False) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, + float("-inf")) + + # load alibi + alibi = (tl.arange(0, BLOCK_N)[None, :] + alibi_start_k - + alibi_start_q[:, None]) * alibi_slope + alibi = tl.where( + (alibi <= 0) & (alibi_start_q[:, None] < cur_batch_seq_len), + alibi, float("-inf")) + qk += alibi + alibi_start_k += BLOCK_N + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + m_i_new = tl.maximum(m_i, m_ij) + p = tl.math.exp(qk - m_i_new[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + + alpha = tl.math.exp(m_i - m_i_new) + l_i_new = alpha * l_i + l_ij + # -- update output accumulator -- + # scale p + # scale acc + acc_scale = alpha + # acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < + cur_batch_seq_len - cur_batch_ctx_len, + other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v, allow_tf32=False) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + + acc = acc / l_i[:, None] + + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + offs_d[None, :] * stride_od) + out_ptrs = Out + off_o + tl.store(out_ptrs, + acc, + mask=offs_m[:, None] < cur_batch_seq_len - cur_batch_ctx_len) + return + + @torch.inference_mode() + def context_attention_fwd(q, + k, + v, + o, + k_cache, + v_cache, + b_loc, + b_start_loc, + b_seq_len, + b_ctx_len, + max_input_len, + alibi_slopes=None): + + cap = torch.cuda.get_device_capability() + BLOCK = 128 if cap[0] >= 8 else 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + num_queries_per_kv = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 8 if Lk <= 64 else 8 + if alibi_slopes is not None: + 
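+            # ALiBi path: launch the kernel that adds the per-head linear
+            # position bias to the attention scores before the softmax
+            # rescaling, instead of the plain kernel below.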
_fwd_kernel_alibi[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + alibi_slopes, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4 + ), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + _fwd_kernel[grid]( + q, + k, + v, + k_cache, + v_cache, + b_loc, + sm_scale, + b_start_loc, + b_seq_len, + b_ctx_len, + v_cache.shape[3], + 8, + o, + b_loc.stride(0), + b_loc.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + k_cache.stride(0), + k_cache.stride(1), + k_cache.stride(2), + k_cache.stride(3), + k_cache.stride( + 4), #[num_blocks, num_kv_heads, head_size/x, block_size, x] + v_cache.stride(0), + v_cache.stride(1), + v_cache.stride(2), + v_cache.stride( + 3), #[num_blocks, num_kv_heads, head_size, block_size] + num_queries_per_kv=num_queries_per_kv, + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000..6d13cf8 --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,151 @@ +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.utils import divide +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.utils import set_weight_attrs + +DEFAULT_VOCAB_PADDING_SIZE = 64 + + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, + rank: int) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, + world_size: int) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank) + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. 
+ org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + super().__init__() + + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.org_vocab_size = org_num_embeddings or num_embeddings + self.num_embeddings_padded = pad_vocab_size(num_embeddings, + padding_size) + self.embedding_dim = embedding_dim + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.tp_size = get_tensor_model_parallel_world_size() + # Divide the weight matrix along the vocaburaly dimension. + self.vocab_start_index, self.vocab_end_index = ( + vocab_range_from_global_vocab_size( + self.num_embeddings_padded, get_tensor_model_parallel_rank(), + self.tp_size)) + self.num_embeddings_per_partition = (self.vocab_end_index - + self.vocab_start_index) + self.weight = Parameter( + torch.empty(self.num_embeddings_per_partition, + self.embedding_dim, + dtype=params_dtype)) + set_weight_attrs(self.weight, { + "parallel_dim": 0, + "weight_loader": self.weight_loader + }) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + parallel_dim = param.parallel_dim + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size + loaded_weight = loaded_weight[self.vocab_start_index:self. + vocab_end_index] + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + input_mask = ((input_ < self.vocab_start_index) | + (input_ >= self.vocab_end_index)) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output + + +class ParallelLMHead(VocabParallelEmbedding): + """Parallelized LM head. + + Output logits weight matrices used in the Sampler. The weight and bias + tensors are padded to make sure they are divisible by the number of + model parallel GPUs. + + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + bias: whether to use bias. + params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. 
+ """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size) + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "parallel_dim": 0, + "weight_loader": self.weight_loader + }) + else: + self.register_parameter("bias", None) + + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weights should be used in the sampler.") diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py new file mode 100644 index 0000000..5848fc0 --- /dev/null +++ b/vllm/model_executor/model_loader.py @@ -0,0 +1,137 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Type + +import torch +import torch.nn as nn + +from vllm.config import DeviceConfig, ModelConfig +from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.weight_utils import (get_quant_config, + initialize_dummy_weights) + + +@contextlib.contextmanager +def _set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + torch.set_default_dtype(old_dtype) + + +def _get_model_architecture(model_config: ModelConfig) -> Type[nn.Module]: + architectures = getattr(model_config.hf_config, "architectures", []) + # Special handling for quantized Mixtral. + # FIXME(woosuk): This is a temporary hack. + if (model_config.quantization is not None + and "MixtralForCausalLM" in architectures): + architectures = ["QuantMixtralForCausalLM"] + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> nn.Module: + lora_config = kwargs.get("lora_config", None) + model_class = _get_model_architecture(model_config) + + # Get the (maybe quantized) linear method. + linear_method = None + if model_config.quantization is not None: + quant_config = get_quant_config(model_config) + capability = (9, 0) + # capability = torch.cuda.get_device_capability() avoid capability error + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() + + with _set_default_torch_dtype(model_config.dtype): + # Create a model instance. + # The weights will be initialized as empty tensors. 
+ try: + # with torch.device contextmanager need torch >= 2.0 + with torch.device(device_config.device): + if hasattr(model_class, "supported_lora_modules"): + model = model_class(model_config.hf_config, linear_method, + lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + else: + model = model_class(model_config.hf_config, linear_method) + if model_config.load_format == "dummy": + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded files. + model.load_weights(model_config.model, model_config.download_dir, + model_config.load_format, model_config.revision) + # for torch < 2.0 + except: + if hasattr(model_class, "supported_lora_modules"): + model = model_class(model_config.hf_config, linear_method, + lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + else: + model = model_class(model_config.hf_config, linear_method) + model = model.cuda() + if model_config.load_format == "dummy": + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded files. + model.load_weights(model_config.model, model_config.download_dir, + model_config.load_format, model_config.revision) + return model.eval() + # TODO align + """ + with torch.device(device_config.device): + if hasattr(model_class, "supported_lora_modules"): + model = model_class(model_config.hf_config, linear_method, + lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + else: + model = model_class(model_config.hf_config, linear_method) + if model_config.load_format == "dummy": + # NOTE(woosuk): For accurate performance evaluation, we assign + # random values to the weights. + initialize_dummy_weights(model) + else: + # Load the weights from the cached or downloaded files. + model.load_weights(model_config.model, model_config.download_dir, + model_config.load_format, model_config.revision) + return model.eval() + """ \ No newline at end of file diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py new file mode 100644 index 0000000..22e0fde --- /dev/null +++ b/vllm/model_executor/models/__init__.py @@ -0,0 +1,107 @@ +import importlib +from typing import List, Optional, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip, is_neuron + +logger = init_logger(__name__) + +# Architecture -> (module, class). 
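The try/except above exists because torch.device(...) only became usable as a context manager in PyTorch 2.0. One alternative, sketched here with a hypothetical helper and assuming the packaging library is available, is an explicit version check rather than catching every exception; this is not what the file does, just an option:

import contextlib
import torch
from packaging import version

def device_scope(device_str: str):
    """Device context on torch >= 2.0, no-op context otherwise."""
    if version.parse(torch.__version__.split("+")[0]) >= version.parse("2.0"):
        return torch.device(device_str)
    return contextlib.nullcontext()

with device_scope("cpu"):
    t = torch.empty(4)    # allocated under the scoped device when supported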
+_MODELS = { + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + "SQLlamaForCausalLM": ("llama_smooth","LlamaForCausalLM"), + "CPMDragonflyForCausalLM": ("cpm", "CPMDragonflyForCausalLM"), + + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "OLMoForCausalLM": ("olmo", "OLMoForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), +} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_PARTIALLY_SUPPORTED_MODELS = { + "Qwen2ForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", + "MistralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", + "MixtralForCausalLM": + "Sliding window attention is not yet supported in ROCm's flash attention", +} + +# Models not supported by Neuron. 
+_NEURON_SUPPORTED_MODELS = {"LlamaForCausalLM": "neuron.llama"} + + +class ModelRegistry: + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + f"Model architecture {model_arch} is partially supported " + "by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + elif is_neuron(): + if model_arch not in _NEURON_SUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "Neuron for now.") + + module_name, model_cls_name = _MODELS[model_arch] + if is_neuron(): + module_name = _NEURON_SUPPORTED_MODELS[model_arch] + module = importlib.import_module( + f"vllm.model_executor.models.{module_name}") + return getattr(module, model_cls_name, None) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + + +__all__ = [ + "ModelRegistry", +] diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py new file mode 100644 index 0000000..eae5e1e --- /dev/null +++ b/vllm/model_executor/models/baichuan.py @@ -0,0 +1,386 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
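For orientation, the registry defined above is meant to be queried with the architecture strings found in a checkpoint's config.json; a brief usage sketch using only names from the _MODELS table:

from vllm.model_executor.models import ModelRegistry

model_cls = ModelRegistry.load_model_cls("BaichuanForCausalLM")
print(model_cls)                                               # class from models/baichuan.py
assert ModelRegistry.load_model_cls("NotARealArch") is None    # unknown archs return None
print(ModelRegistry.get_supported_archs()[:5])                 # first few registered architectures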
+"""Inference-only BaiChuan model compatible with HuggingFace weights.""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BaiChuanMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class BaiChuanAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + position_embedding: str, + rope_theta: float = 10000, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.postion_embedding = position_embedding + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + # pylint: disable=invalid-name + self.W_pack = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + # Create the alibi slopes and slice them. + if self.postion_embedding == "ALIBI": + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) + else: + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=self.rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, self.head_dim, + self.scaling) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.W_pack(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.postion_embedding != "ALIBI": + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class BaiChuanDecoderLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = BaiChuanAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + position_embedding=position_embedding, + rope_theta=rope_theta, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = BaiChuanMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def 
forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class BaiChuanModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + BaiChuanDecoderLayer(config, position_embedding, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class BaiChuanBaseForCausalLM(nn.Module): + + def __init__(self, + config, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = BaiChuanModel(config, position_embedding, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if name == "lm_head.weight": + # Unlike Baichuan, Baichuan2 normalizes the head weights. 
Refer to: + # https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat/blob/84603cde5ebffb6084e476cfaeceaf0b8b91fe54/modeling_baichuan.py#L508 + # Distinguish between Baichuan and Baichuan2 by checking the + # vocab size. This is suggested by + # https://github.com/vllm-project/vllm/pull/1022#discussion_r1325652704 + is_baichuan2 = self.config.vocab_size == 125696 + if is_baichuan2: + loaded_weight = torch.nn.functional.normalize( + loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class BaichuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 13B and Baichuan2 7B/13B.""" + + def __init__(self, + config, + linear_method: Optional[LinearMethodBase] = None): + if config.hidden_size in [4096,6656]: # baichuan2 7b 33b + super().__init__(config, "ROPE", linear_method) + else: # baichuan 13b, baichuan2 13b + super().__init__(config, "ALIBI", linear_method) + + +class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): + """Baichuan 7B.""" + + def __init__(self, + config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__(config, "ROPE", linear_method) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py new file mode 100644 index 0000000..4adfb6b --- /dev/null +++ b/vllm/model_executor/models/bloom.py @@ -0,0 +1,330 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/bloom/modeling_bloom.py +# Copyright 2023 The CacheFlow team. +# Copyright 2022 HuggingFace Inc. team and BigScience workshop. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
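A toy illustration of the row normalization applied to the Baichuan2 head weights in the loader above; the shapes are made up, the real weight is [vocab_size, hidden_size]:

import torch
import torch.nn.functional as F

w = torch.randn(8, 4)            # stand-in for an lm_head weight
w_normed = F.normalize(w)        # L2-normalizes along dim=1, i.e. each vocab row
print(w_normed.norm(dim=1))      # ~1.0 for every row, as Baichuan2 expects at inference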
+"""Inference-only BLOOM model compatible with HuggingFace weights.""" +import math +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import BloomConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor( + 2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32, + ) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32, + ) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(start=1, + end=1 + 2 * num_remaining_heads, + step=2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + return slopes + + +class BloomAttention(nn.Module): + + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + self.total_num_heads = config.n_head + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + del position_ids # Unused. 
+ qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.dense(attn_output) + return output + + +class BloomMLP(nn.Module): + + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.dense_h_to_4h = ColumnParallelLinear( + hidden_size, + 4 * hidden_size, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + linear_method=linear_method, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.dense_h_to_4h(x) + x = self.gelu_impl(x) + x, _ = self.dense_4h_to_h(x) + return x + + +class BloomBlock(nn.Module): + + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.input_layernorm = nn.LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + self.self_attention = BloomAttention(config, linear_method) + self.post_attention_layernorm = nn.LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + self.mlp = BloomMLP(config, linear_method) + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + + # Layer norm post the self attention. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + # Self attention. + attention_output = self.self_attention( + position_ids=position_ids, + hidden_states=layernorm_output, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + attention_output = attention_output + residual + layernorm_output = self.post_attention_layernorm(attention_output) + + # Get residual + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = attention_output + + # MLP. 
+ output = self.mlp(layernorm_output) + residual + return output + + +class BloomModel(nn.Module): + + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.embed_dim = config.hidden_size + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.word_embeddings_layernorm = nn.LayerNorm( + self.embed_dim, eps=config.layer_norm_epsilon) + + # Transformer blocks + self.h = nn.ModuleList([ + BloomBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + + # Final Layer Norm + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.word_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class BloomForCausalLM(nn.Module): + + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = BloomModel(config, linear_method) + self.lm_head_weight = self.transformer.word_embeddings.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if name == "lm_head.weight": + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. 
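To see what the weight conversion described in the NOTE above does, here is a toy version with tiny made-up dimensions; it mirrors the view/transpose/reshape that follows:

import torch

num_heads, head_size, hidden = 2, 3, 4
# Checkpoint order: rows grouped per head as [q_h | k_h | v_h].
w = torch.arange(num_heads * 3 * head_size * hidden, dtype=torch.float32)
w = w.view(num_heads * 3 * head_size, hidden)

# Regroup to [all q heads | all k heads | all v heads], the layout the fused
# QKVParallelLinear expects.
w_fixed = (w.view(num_heads, 3, head_size, hidden)
             .transpose(0, 1)
             .reshape(num_heads * 3 * head_size, hidden))
print(w_fixed.shape)    # torch.Size([18, 4])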
+ output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py new file mode 100644 index 0000000..20d0819 --- /dev/null +++ b/vllm/model_executor/models/chatglm.py @@ -0,0 +1,396 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +"""Inference-only ChatGLM model compatible with THUDM weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import LayerNorm + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs import ChatGLMConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GLMAttention(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.multi_query_attention = config.multi_query_attention + self.total_num_kv_heads = (config.multi_query_group_num + if config.multi_query_attention else + config.num_attention_heads) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
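The two KV-head branches above reduce to a small rule: shard when there are at least as many KV heads as tensor-parallel ranks, replicate otherwise. A numeric sketch with a hypothetical helper (head counts are illustrative):

def local_kv_heads(total_kv_heads: int, tp_size: int) -> int:
    if total_kv_heads >= tp_size:
        assert total_kv_heads % tp_size == 0    # shard KV heads across ranks
    else:
        assert tp_size % total_kv_heads == 0    # replicate KV heads on ranks
    return max(1, total_kv_heads // tp_size)

print(local_kv_heads(32, 8))   # 4 KV heads per rank (standard MHA/GQA)
print(local_kv_heads(2, 8))    # 1 per rank: the two multi-query groups are replicated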
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = config.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + linear_method=linear_method, + ) + self.dense = RowParallelLinear( + self.total_num_heads * self.head_dim, + config.hidden_size, + bias=config.add_bias_linear, + linear_method=linear_method, + ) + + # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 + rope_ratio = getattr(config, "rope_ratio", 1.0) + max_positions = getattr(config, "seq_length", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim // 2, + max_position=max_positions, + base=10000 * rope_ratio, + is_neox_style=False, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + ) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + key_cache, value_cache = kv_cache + context_layer = self.attn( + q, + k, + v, + key_cache, + value_cache, + input_metadata, + ) + attn_output, _ = self.dense(context_layer) + return attn_output + + +class GLMMLP(nn.Module): + """MLP. + + MLP will take the input with h hidden state, project it to 4*h + hidden dimension, perform nonlinear transformation, and project the + state back into h hidden dimension. + """ + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + self.add_bias = config.add_bias_linear + + # Project to 4h. + self.dense_h_to_4h = MergedColumnParallelLinear( + config.hidden_size, + [config.ffn_hidden_size] * 2, + bias=config.add_bias_linear, + linear_method=linear_method, + ) + + self.activation_func = SiluAndMul() + + # Project back to h. + self.dense_4h_to_h = RowParallelLinear( + config.ffn_hidden_size, + config.hidden_size, + bias=config.add_bias_linear, + linear_method=linear_method, + ) + + def forward(self, hidden_states): + # [s, b, 4hp] + intermediate_parallel, _ = self.dense_h_to_4h(hidden_states) + intermediate_parallel = self.activation_func(intermediate_parallel) + # [s, b, h] + output, _ = self.dense_4h_to_h(intermediate_parallel) + return output + + +class GLMBlock(nn.Module): + """A single transformer layer. + + Transformer layer takes input with size [s, b, h] and returns an + output of the same size. + """ + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.apply_residual_connection_post_layernorm = ( + config.apply_residual_connection_post_layernorm) + + self.fp32_residual_connection = config.fp32_residual_connection + + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Layernorm on the input data. + self.input_layernorm = layer_norm_func(config.hidden_size, + eps=config.layernorm_epsilon) + + # Self attention. 
+ self.self_attention = GLMAttention(config, linear_method) + self.hidden_dropout = config.hidden_dropout + + # Layernorm on the attention output + self.post_attention_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + # MLP + self.mlp = GLMMLP(config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + # hidden_states: [num_tokens, h] + # Layer norm at the beginning of the transformer layer. + layernorm_output = self.input_layernorm(hidden_states) + # Self attention. + attention_output = self.self_attention( + hidden_states=layernorm_output, + position_ids=position_ids, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = hidden_states + + layernorm_input = residual + attention_output + + # Layer norm post the self attention. + layernorm_output = self.post_attention_layernorm(layernorm_input) + + # Second residual connection. + if self.apply_residual_connection_post_layernorm: + residual = layernorm_output + else: + residual = layernorm_input + + output = self.mlp(layernorm_output) + residual + + return output + + +class GLMTransformer(nn.Module): + """Transformer class.""" + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.post_layer_norm = config.post_layer_norm + + # Number of layers. + self.num_layers = config.num_layers + + # Transformer layers. + self.layers = nn.ModuleList( + [GLMBlock(config, linear_method) for i in range(self.num_layers)]) + + if self.post_layer_norm: + layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm + # Final layer norm before output. + self.final_layernorm = layer_norm_func( + config.hidden_size, eps=config.layernorm_epsilon) + + def forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + for i in range(self.num_layers): + layer = self.layers[i] + hidden_states = layer( + hidden_states=hidden_states, + position_ids=position_ids, + kv_cache=kv_caches[i], + input_metadata=input_metadata, + ) + # Final layer norm. + if self.post_layer_norm: + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class ChatGLMModel(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + self.embedding = VocabParallelEmbedding(config.padded_vocab_size, + config.hidden_size) + + self.num_layers = config.num_layers + self.multi_query_group_num = config.multi_query_group_num + self.kv_channels = config.kv_channels + self.encoder = GLMTransformer(config, linear_method) + + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.embedding(input_ids) + + # Run encoder. 
+ hidden_states = self.encoder( + hidden_states=inputs_embeds, + position_ids=position_ids, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + return hidden_states + + +class ChatGLMForCausalLM(nn.Module): + + def __init__( + self, + config: ChatGLMConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config: ChatGLMConfig = config + self.linear_method = linear_method + self.transformer = ChatGLMModel(config, linear_method) + self.lm_head_weight = self.transformer.output_layer.weight + self.sampler = Sampler(config.padded_vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + try: + # torch < 2.0 do not have "remove_duplicate=False" in named_parameters + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_pos_emb.inv_freq" in name: + continue + if "word_embeddings" in name: + name = name.replace(".word_embeddings", "") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + except: + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_pos_emb.inv_freq" in name: + continue + if "word_embeddings" in name: + name = name.replace(".word_embeddings", "") + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + try: + param = params_dict[name] + except: + assert name == "transformer.output_layer.weight" + param = self.transformer.output_layer.weight + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/cpm.py b/vllm/model_executor/models/cpm.py new file mode 100644 index 0000000..d693025 --- /dev/null +++ b/vllm/model_executor/models/cpm.py @@ -0,0 +1,368 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
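On the load_weights fallback above: named_parameters(remove_duplicate=False) matters when a checkpoint ties weights (for example an LM head sharing an embedding weight), because the default behavior collapses shared parameters to a single name and the tied alias would then be missing from params_dict, which appears to be what the except branch guards against. A toy check of that behavior, unrelated to this repo's classes:

import torch.nn as nn

emb = nn.Embedding(10, 4)
head = nn.Linear(4, 10, bias=False)
head.weight = emb.weight                     # tie the two weights
tied = nn.ModuleDict({"emb": emb, "head": head})
print(len(dict(tied.named_parameters())))                        # 1: duplicates collapsed
print(len(dict(tied.named_parameters(remove_duplicate=False))))  # 2: both names visible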
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import (RMSNorm) + + +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs import CPMDragonflyConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class CPMMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class CPMAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class CPMDecoderLayer(nn.Module): + + def __init__( + self, + config: CPMDragonflyConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = CPMAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = CPMMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.scale_states = config.scale_states # hidden_states + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual, scale = self.scale_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual, scale = self.scale_states) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class CPMModel(nn.Module): + + def __init__( + self, + config: CPMDragonflyConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + 
self.padding_idx = getattr(config,"pad_token_id",None) + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + CPMDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.scale = self.config.scale + self.scale_emb = self.config.scale_emb # embeding + self.scale_states = self.config.scale_states # hidden_states + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + if self.scale: + hidden_states = self.embed_tokens(input_ids) * self.scale_emb + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual, scale=self.scale_states) + return hidden_states + + +class CPMDragonflyForCausalLM(nn.Module): + + def __init__( + self, + config: CPMDragonflyConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = CPMModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + self.apply_inf = False + self.sampler_weight = None + self.scale_width = self.config.scale_width # output logits + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + self.apply_inf = input_metadata.is_prompt + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + # next_tokens = self.sampler(self.sampler_weight, + # hidden_states, + # sampling_metadata, + # apply_inf = self.apply_inf, + # index=1, + # skip_prompt=True, # skip prompt tokens when apply _apply_penalties function + # logits_scale=self.scale_width, # apply scale in sampler to avoid + # ) # an elementwise op on all outputs + next_tokens = self.sampler(self.sampler_weight, + hidden_states, + sampling_metadata, + logits_scale=self.scale_width, + ) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + self.sampler_weight = self.model.embed_tokens.weight if self.config.tie_lm_head == False else self.lm_head.weight \ No newline at end of file diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py new file mode 100644 index 0000000..abf4a46 --- /dev/null +++ b/vllm/model_executor/models/decilm.py @@ -0,0 +1,127 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 DeciAI Research Team. All rights reserved. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on MistralAI GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only DeciLM model compatible with HuggingFace weights.""" + +from typing import Optional + +import torch +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.models.llama import LlamaForCausalLM +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) + + +class DeciLMForCausalLM(LlamaForCausalLM): + """ + Implementation for https://huggingface.co/Deci/DeciLM-7b-instruct. + Based on the llama executor. + + The main difference is that DeciLM uses Variable Grouped Query Attention. + The constant number of GQA heads in the decoder is overridden with a value + per layer. + + Usually, in the HuggingFace implementation, instead of + "config.num_key_value_heads", we use + "config.num_key_value_heads_per_layer[i]" which varies. + + Currently, PagedAttention does not work well with variable GQA, so we + normalize the weights upon loading, and use uniform GQA with the max value + instead. 
+ """ + + def __init__( + self, + config: Optional[PretrainedConfig] = None, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + config.num_key_value_heads = max(config.num_key_value_heads_per_layer) + delattr(config, "num_key_value_heads_per_layer") + super().__init__(config=config, + linear_method=linear_method, + lora_config=lora_config) + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + if "k_proj" in name or "v_proj" in name: + loaded_weight = self._degroup_weight(loaded_weight) + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor: + hidden_size = self.config.hidden_size + head_size = self.config.hidden_size // self.config.num_attention_heads + target_num_kv_heads = self.config.num_key_value_heads + num_kv_heads = loaded_weight.shape[0] // head_size + n_repeats = target_num_kv_heads / num_kv_heads + assert n_repeats == int(n_repeats) + + n_repeats = int(n_repeats) + loaded_weight = loaded_weight.view(num_kv_heads, head_size, + hidden_size) + loaded_weight = torch.repeat_interleave(loaded_weight, + repeats=n_repeats, + dim=0) + loaded_weight = loaded_weight.reshape(target_num_kv_heads * head_size, + hidden_size) + + return loaded_weight diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py new file mode 100644 index 0000000..6dba952 --- /dev/null +++ b/vllm/model_executor/models/deepseek.py @@ -0,0 +1,444 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Deepseek model.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + ReplicatedLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class DeepseekMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + reduce_results: bool = True, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + reduce_results=reduce_results) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class DeepseekMoE(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.n_routed_experts = config.n_routed_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.n_routed_experts}.") + + self.experts = nn.ModuleList([ + DeepseekMLP(hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False) + for idx in range(self.n_routed_experts) + ]) + self.pack_params() + + self.gate = ReplicatedLinear(config.hidden_size, + self.n_routed_experts, + bias=False, + linear_method=None) + + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + reduce_results=False, + ) + + def pack_params(self): + w1 = [] + w2 = [] + for expert in self.experts: + w1.append(expert.gate_up_proj.weight) + w2.append(expert.down_proj.weight) + self.w1 = torch._utils._flatten_dense_tensors(w1) + w1s = torch._utils._unflatten_dense_tensors(self.w1, w1) + for data, param in zip(w1s, w1): + param.data = data + self.w1 = self.w1.view(len(w1), *w1s[0].shape) + + self.w2 = torch._utils._flatten_dense_tensors(w2) + w2s = torch._utils._unflatten_dense_tensors(self.w2, w2) + for data, param in zip(w2s, w2): + param.data = data + + self.w2 = self.w2.view(len(w2), *w2s[0].shape) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + if self.config.n_shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.w1, + self.w2, + router_logits, + self.top_k, + renormalize=self.config.norm_topk_prob, + inplace=True) + + if self.config.n_shared_experts is not None: + final_hidden_states = final_hidden_states + shared_output + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_dim) + + +class DeepseekAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, 
so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class DeepseekDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = DeepseekAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + if (config.n_routed_experts is not None and \ + layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + else: + self.mlp = DeepseekMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + 
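+ # NOTE (editorial comment, not part of the original commit): like
+ # input_layernorm above, post_attention_layernorm is the fused
+ # add-then-RMSNorm path: when called with a residual it adds the
+ # attention output into the residual stream and returns
+ # (normalized_states, new_residual) in a single step, rather than doing a
+ # separate elementwise add followed by a norm.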
hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class DeepseekModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + DeepseekDecoderLayer(config, + layer_idx, + linear_method=linear_method) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class DeepseekForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = DeepseekModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if (("mlp.experts." in name or "mlp.shared_experts." 
in name) + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py new file mode 100644 index 0000000..2b5e022 --- /dev/null +++ b/vllm/model_executor/models/falcon.py @@ -0,0 +1,447 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py +# Copyright 2023 The vLLM team. +# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Falcon model.""" + +import math +from typing import List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import LayerNorm +from transformers import FalconConfig as HF_FalconConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs import RWConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] +FalconConfig = Union[HF_FalconConfig, RWConfig] + + +def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: + closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) + base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), + dtype=torch.float32) + powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != total_num_heads: + extra_base = torch.tensor( + 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))), + dtype=torch.float32) + num_remaining_heads = min(closest_power_of_2, + total_num_heads - closest_power_of_2) + extra_powers = torch.arange(1, + 1 + 2 * num_remaining_heads, + 2, + dtype=torch.int32) + slopes = torch.cat( + [slopes, torch.pow(extra_base, extra_powers)], dim=0) + + return slopes + + +class FalconAttention(nn.Module): + + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + 
self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.head_dim = self.hidden_size // self.total_num_heads + assert self.head_dim * self.total_num_heads == self.hidden_size + + self.new_decoder_architecture = config.new_decoder_architecture + self.multi_query = config.multi_query + + if self.new_decoder_architecture: + self.total_num_kv_heads = config.num_kv_heads + elif self.multi_query: + self.total_num_kv_heads = 1 + else: + self.total_num_kv_heads = self.total_num_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method, + ) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method, + reduce_results=self.reduce_row_parallel_results) + + self.use_rotary = config.rotary + self.use_alibi = config.alibi + assert not (self.use_rotary and self.use_alibi), ( + "Rotary and alibi are mutually exclusive.") + + if self.use_rotary: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) + elif self.use_alibi: + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = (_get_alibi_slopes(self.total_num_heads) * + self.inv_norm_factor) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.inv_norm_factor, + num_kv_heads=self.num_kv_heads, + alibi_slopes=alibi_slopes) + else: + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.inv_norm_factor, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.use_rotary: + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output, bias = self.dense(attn_output) + return attn_output, bias + + +class 
FalconMLP(nn.Module): + + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + + self.dense_h_to_4h = ColumnParallelLinear(hidden_size, + 4 * hidden_size, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + self.dense_4h_to_h = RowParallelLinear( + 4 * hidden_size, + hidden_size, + bias=config.bias, + skip_bias_add=True, + reduce_results=self.reduce_row_parallel_results, + linear_method=linear_method) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. + x, bias = self.dense_h_to_4h(x) + if bias is not None: + x += bias + x = self.act(x) + x, bias = self.dense_4h_to_h(x) + return x, bias + + +class FalconDecoderLayer(nn.Module): + + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.self_attention = FalconAttention(config, linear_method) + self.mlp = FalconMLP(config, linear_method) + self.config = config + + if config.new_decoder_architecture: + # The layer norm before self-attention + self.ln_attn = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + # The layer norm before the MLP + self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + else: + self.input_layernorm = LayerNorm(hidden_size, + eps=config.layer_norm_epsilon) + if not config.parallel_attn: + self.post_attention_layernorm = LayerNorm( + hidden_size, eps=config.layer_norm_epsilon) + + self.reduce_row_parallel_results = not (config.new_decoder_architecture + or config.parallel_attn) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + + if self.config.new_decoder_architecture: + attention_layernorm_out = self.ln_attn(hidden_states) + mlp_layernorm_out = self.ln_mlp(hidden_states) + else: + attention_layernorm_out = self.input_layernorm(hidden_states) + + # Self attention. + attention_output, attention_bias = self.self_attention( + positions=positions, + hidden_states=attention_layernorm_out, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + if self.reduce_row_parallel_results and attention_bias is not None: + attention_output += attention_bias + + if not self.config.new_decoder_architecture: + if self.config.parallel_attn: + mlp_layernorm_out = attention_layernorm_out + else: + residual += attention_output + mlp_layernorm_out = self.post_attention_layernorm(residual) + + # MLP. + mlp_output, mlp_bias = self.mlp(mlp_layernorm_out) + if self.reduce_row_parallel_results and mlp_bias is not None: + mlp_output += mlp_bias + + if not self.reduce_row_parallel_results: + # When MLP and Attention layers are parallel, we can use + # only one all-reduce operator to reduce the results from + # both MLP and Attention layers. 
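+ # NOTE (editorial comment, not part of the original commit): in this fused
+ # path the attention output is summed into mlp_output first, a single
+ # tensor-parallel all-reduce is issued, and only then are the skipped
+ # attention/MLP biases added back, because the biases are identical on
+ # every rank and must not be summed tp_size times by the reduction.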
+ mlp_output += attention_output + mlp_output = tensor_model_parallel_all_reduce(mlp_output) + if attention_bias is not None: + mlp_output += attention_bias + if mlp_bias is not None: + mlp_output += mlp_bias + + output = mlp_output + residual + return output + + +class FalconModel(nn.Module): + + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.use_alibi = config.alibi + + # Embedding + LN Embedding + self.word_embeddings = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + + # Transformer blocks + self.h = nn.ModuleList([ + FalconDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + + # Final Layer Norm + self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.word_embeddings(input_ids) + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class FalconForCausalLM(nn.Module): + + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = FalconModel(config, linear_method) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.LongTensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer( + input_ids, + positions, + kv_caches, + input_metadata, + ) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + total_num_heads = self.config.num_attention_heads + if self.config.new_decoder_architecture: + total_num_kv_heads = self.config.num_kv_heads + elif self.config.multi_query: + total_num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "query_key_value" in name: + output_dim = getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape + if output_dim is not None: + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, + -1) + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, + num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py new file mode 100644 index 0000000..0394813 --- /dev/null +++ b/vllm/model_executor/models/gemma.py @@ -0,0 +1,346 @@ +# coding=utf-8 +# Copyright 2023 The vLLM team. +# Copyright (c) Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Gemma model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import GemmaConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import GeluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GemmaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + self.act_fn = GeluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GemmaAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + max_position_embeddings: int = 8192, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = head_dim + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=self.rope_theta, + is_neox_style=True, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GemmaDecoderLayer(nn.Module): + + def __init__( + self, + config: GemmaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = GemmaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + head_dim=config.head_dim, + max_position_embeddings=config.max_position_embeddings, + rope_theta=config.rope_theta, + linear_method=linear_method, + ) + self.mlp = GemmaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class GemmaModel(nn.Module): + + def __init__( + self, + config: GemmaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + GemmaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, 
+ ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + # Normalize the embedding by sqrt(hidden_size) + hidden_states *= self.config.hidden_size**0.5 + + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class GemmaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + # Gemma does not apply LoRA to the embedding layer. + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: GemmaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + del lora_config # Unused. + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = GemmaModel(config, linear_method) + self.sampler = Sampler(config.vocab_size) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.model.embed_tokens.weight, + hidden_states, sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + loaded_params = set() + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + for (param_name, shard_name, shard_id) in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # GemmaRMSNorm is different from Llama's in that it multiplies + # (1 + weight) to the output, instead of just weight. + if "norm.weight" in name: + loaded_weight += 1.0 + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + "Some weights are not initialized from checkpoints: " + f"{unloaded_params}") diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py new file mode 100644 index 0000000..661da0f --- /dev/null +++ b/vllm/model_executor/models/gpt2.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. 
team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-2 model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPT2Config + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPT2Attention(nn.Module): + + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scale) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPT2MLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def 
forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class GPT2Block(nn.Module): + + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPT2Attention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPT2MLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPT2Model(nn.Module): + + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + assert not config.scale_attn_by_inverse_layer_idx + assert not config.reorder_and_upcast_attn + self.embed_dim = config.hidden_size + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList([ + GPT2Block(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPT2LMHeadModel(nn.Module): + + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = GPT2Model(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = 
None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + # GPT-2 ties the weights of the embedding layer and the final + # linear layer. + continue + if ".attn.bias" in name or ".attn.masked_bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] + # The HF's GPT-2 implementation uses Conv1D instead of Linear. + # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. + for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: + if conv1d_weight_name not in name: + continue + if not name.endswith(".weight"): + continue + loaded_weight = loaded_weight.t() + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py new file mode 100644 index 0000000..ef4c1d4 --- /dev/null +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -0,0 +1,279 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py +# Copyright 2023 The vLLM team. +# Copyright 2023 CTranslate2, and Michael Feil +# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPTBigCodeConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPTBigCodeAttention(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = config.hidden_size + total_num_heads = config.num_attention_heads + self.tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert total_num_heads % self.tensor_model_parallel_world_size == 0 + self.num_heads = (total_num_heads // + self.tensor_model_parallel_world_size) + self.head_dim = self.hidden_size // total_num_heads + self.scale = self.head_dim**-0.5 + + self.multi_query = config.multi_query + if self.multi_query: + total_num_kv_heads = 1 + self.num_kv_heads = 1 + else: + total_num_kv_heads = total_num_heads + self.num_kv_heads = self.num_heads + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) + + self.c_proj = RowParallelLinear( + self.hidden_size, + self.hidden_size, + bias=True, + linear_method=linear_method, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scale, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ + self.hidden_size // self.tensor_model_parallel_world_size, + self.kv_dim, self.kv_dim + ], + dim=-1, + ) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + attn_output, _ = self.c_proj(attn_output) + return attn_output + + +class GPTBigMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + self.c_fc = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = 
self.c_proj(hidden_states) + return hidden_states + + +class GPTBigCodeBlock(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.hidden_size + inner_dim = (config.n_inner if config.n_inner is not None else 4 * + hidden_size) + + self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.attn = GPTBigCodeAttention(config, linear_method) + self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) + self.mlp = GPTBigMLP(inner_dim, config, linear_method) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + # residual connection + hidden_states = attn_output + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + feed_forward_hidden_states = self.mlp(hidden_states) + # residual connection + hidden_states = residual + feed_forward_hidden_states + return hidden_states + + +class GPTBigCodeModel(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert not config.add_cross_attention + + self.embed_dim = config.hidden_size + + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) + self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) + self.h = nn.ModuleList([ + GPTBigCodeBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTBigCodeForCausalLM(nn.Module): + + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = GPTBigCodeModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if 
"lm_head.weight" in name: + continue + if ".attn.bias" in name: + # Skip attention mask. + # NOTE: "c_attn.bias" should not be skipped. + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py new file mode 100644 index 0000000..5bab30d --- /dev/null +++ b/vllm/model_executor/models/gpt_j.py @@ -0,0 +1,284 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gptj/modeling_gptj.py +# Copyright 2023 The vLLM team. +# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GPT-J model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPTJConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPTJAttention(nn.Module): + + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=False, + linear_method=linear_method, + ) + self.out_proj = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=False, + linear_method=linear_method, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + scaling = self.head_size**-0.5 + assert getattr(config, "rotary", True) + assert config.rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=config.rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + 
is_neox_style=False, + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + attn_output, _ = self.out_proj(attn_output) + return attn_output + + +class GPTJMLP(nn.Module): + + def __init__( + self, + intermediate_size: int, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.n_embd + self.fc_in = ColumnParallelLinear( + hidden_size, + intermediate_size, + linear_method=linear_method, + ) + self.fc_out = RowParallelLinear( + intermediate_size, + hidden_size, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.activation_function, quant_config, + intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc_out(hidden_states) + return hidden_states + + +class GPTJBlock(nn.Module): + + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + inner_dim = 4 * config.n_embd if config.n_inner is None else config.n_inner + self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) + self.attn = GPTJAttention(config, linear_method) + self.mlp = GPTJMLP(inner_dim, config, linear_method) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_output = self.attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + mlp_output = self.mlp(hidden_states) + hidden_states = attn_output + mlp_output + residual + return hidden_states + + +class GPTJModel(nn.Module): + + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.n_embd + self.wte = VocabParallelEmbedding( + config.vocab_size, + self.embed_dim, + ) + self.h = nn.ModuleList( + [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) + self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + for i in range(len(self.h)): + layer = self.h[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.ln_f(hidden_states) + return hidden_states + + +class GPTJForCausalLM(nn.Module): + + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + assert not config.tie_word_embeddings + self.transformer = GPTJModel(config, linear_method) + self.lm_head = ParallelLMHead( + config.vocab_size, + config.n_embd, + bias=True, + ) + 
self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata, self.lm_head.bias) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "attn.bias" in name or "attn.masked_bias" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py new file mode 100644 index 0000000..8f7e106 --- /dev/null +++ b/vllm/model_executor/models/gpt_neox.py @@ -0,0 +1,294 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
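[Editor's note] The GPT-J block above uses a parallel residual: the attention and MLP branches both consume the same ln_1 output and are added to the residual in a single step (GPT-NeoX below exposes the same choice via use_parallel_residual). A plain-PyTorch sketch with stand-in linear layers, purely illustrative and not part of the patch:

import torch
from torch import nn

hidden = 16
ln_1 = nn.LayerNorm(hidden)
attn = nn.Linear(hidden, hidden)   # stand-in for the attention branch
mlp = nn.Linear(hidden, hidden)    # stand-in for the MLP branch

x = torch.randn(2, hidden)
h = ln_1(x)
# Parallel formulation: one residual add covers both branches.
x = attn(h) + mlp(h) + x
# Compare the sequential formulation used by GPTBigCodeBlock above:
#   x = x + attn(ln_1(x)); x = x + mlp(ln_2(x))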
+"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import GPTNeoXConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class GPTNeoXAttention(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + self.bias = getattr(config, "attention_bias", True) + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + self.query_key_value = QKVParallelLinear( + config.hidden_size, + self.head_size, + self.total_num_heads, + bias=self.bias, + linear_method=linear_method, + ) + self.dense = RowParallelLinear( + config.hidden_size, + config.hidden_size, + bias=self.bias, + linear_method=linear_method, + ) + scaling = self.head_size**-0.5 + rotary_dim = int(self.head_size * config.rotary_pct) + assert rotary_dim % 2 == 0 + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.query_key_value(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.dense(attn_output) + return output + + +class GPTNeoXMLP(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.dense_h_to_4h = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + linear_method=linear_method, + ) + self.dense_4h_to_h = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.hidden_act, quant_config, + config.intermediate_size) + + def forward(self, hidden_states): + hidden_states, _ = 
self.dense_h_to_4h(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.dense_4h_to_h(hidden_states) + return hidden_states + + +class GPTNeoXLayer(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.use_parallel_residual = config.use_parallel_residual + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.attention = GPTNeoXAttention(config, linear_method) + self.mlp = GPTNeoXMLP(config, linear_method) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + attn_input = self.input_layernorm(hidden_states) + attn_output = self.attention( + position_ids=position_ids, + hidden_states=attn_input, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + if self.use_parallel_residual: + # pseudocode: + # x = x + attn(ln1(x)) + mlp(ln2(x)) + mlp_input = self.post_attention_layernorm(hidden_states) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + hidden_states + else: + # pseudocode: + # x = x + attn(ln1(x)) + # x = x + mlp(ln2(x)) + attn_output = attn_output + hidden_states + mlp_input = self.post_attention_layernorm(attn_output) + mlp_output = self.mlp(mlp_input) + hidden_states = mlp_output + attn_output + return hidden_states + + +class GPTNeoXModel(nn.Module): + + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + + self.embed_in = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + GPTNeoXLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.final_layer_norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_in(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + position_ids, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class GPTNeoXForCausalLM(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.gpt_neox = GPTNeoXModel(config, linear_method) + self.embed_out = ParallelLMHead( + config.vocab_size, + config.hidden_size, + ) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.gpt_neox(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.embed_out.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = 
dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if ("attention.bias" in name or "attention.masked_bias" in name + or "rotary_emb.inv_freq" in name): + continue + param = params_dict[name] + + if "query_key_value" in name: + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). + # Thus, we need weight conversion. + output_dim = getattr(param, "output_dim", None) + num_heads = self.config.num_attention_heads + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py new file mode 100644 index 0000000..ebf1d8a --- /dev/null +++ b/vllm/model_executor/models/internlm2.py @@ -0,0 +1,325 @@ +# -*- coding: utf-8 -*- +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class InternLM2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.w2 = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.w2(x) + return x + + +class InternLM2Attention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.wqkv = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.wo = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.wqkv(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.wo(attn_output) + return output + + +class InternLMDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.attention = InternLM2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.feed_forward = InternLM2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.attention_norm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.ffn_norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.attention_norm(hidden_states) + else: + hidden_states, residual = self.attention_norm( + hidden_states, residual) + hidden_states = self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ffn_norm(hidden_states, residual) + hidden_states = self.feed_forward(hidden_states) + return hidden_states, residual + + +class InternLM2Model(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.tok_embeddings = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + InternLMDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.tok_embeddings(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class InternLM2ForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = InternLM2Model(config, linear_method) + self.output = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.output.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w1", 0), + ("gate_up_proj", "w3", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + if "wqkv" in name: + config = self.config + kv_groups = config.num_attention_heads // config.num_key_value_heads + head_dim = config.hidden_size // config.num_attention_heads + loaded_weight = loaded_weight.view(-1, 2 + kv_groups, + head_dim, + loaded_weight.shape[-1]) + wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], + dim=1) + wq = wq.reshape(-1, wq.shape[-1]) + wk = wk.reshape(-1, wk.shape[-1]) + wv = wv.reshape(-1, wv.shape[-1]) + weight_loader = param.weight_loader + weight_loader(param, wq, 'q') + weight_loader(param, wk, 'k') + weight_loader(param, wv, 'v') + else: + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py new file mode 100644 index 0000000..d35887c --- /dev/null +++ b/vllm/model_executor/models/llama.py @@ -0,0 +1,391 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
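[Editor's note] The "wqkv" branch in InternLM2ForCausalLM.load_weights above undoes InternLM2's interleaved checkpoint layout, in which each key/value head is stored together with its group of query heads. A self-contained sketch of that reshaping with made-up dimensions (illustrative only, not the vLLM loader itself):

import torch

num_attention_heads = 8
num_key_value_heads = 2
head_dim = 4
hidden_size = num_attention_heads * head_dim

kv_groups = num_attention_heads // num_key_value_heads  # query heads per kv head
# The checkpoint stacks each kv group as [q_0 .. q_{g-1}, k, v] along dim 0.
wqkv = torch.randn((num_attention_heads + 2 * num_key_value_heads) * head_dim,
                   hidden_size)

w = wqkv.view(-1, 2 + kv_groups, head_dim, wqkv.shape[-1])
wq, wk, wv = torch.split(w, [kv_groups, 1, 1], dim=1)
wq = wq.reshape(-1, wq.shape[-1])  # (num_attention_heads * head_dim, hidden)
wk = wk.reshape(-1, wk.shape[-1])  # (num_key_value_heads * head_dim, hidden)
wv = wv.reshape(-1, wv.shape[-1])

assert wq.shape[0] == num_attention_heads * head_dim
assert wk.shape[0] == num_key_value_heads * head_dim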
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + bias: bool = False, + sliding_window: Optional[int] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=bias, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=bias, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=sliding_window) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + sliding_window = getattr(config, "sliding_window", None) + self.self_attn = LlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + bias=getattr(config, "bias", False), + sliding_window=sliding_window, + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: 
Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + LlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LlamaForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + 
continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama_smooth.py b/vllm/model_executor/models/llama_smooth.py new file mode 100644 index 0000000..14707c8 --- /dev/null +++ b/vllm/model_executor/models/llama_smooth.py @@ -0,0 +1,409 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
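[Editor's note] The stacked_params_mapping loop above (the GPT-J and InternLM2 loaders follow the same pattern) redirects per-projection checkpoint tensors such as q_proj/k_proj/v_proj into the fused qkv_proj and gate_up_proj parameters, tagging each with a shard id so the parameter's weight_loader knows which slice to fill. A toy illustration of just the name-routing step; the route helper is hypothetical and not part of vLLM:

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def route(checkpoint_name: str):
    """Return (fused_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in checkpoint_name:
            return checkpoint_name.replace(weight_name, param_name), shard_id
    return checkpoint_name, None

print(route("model.layers.0.self_attn.k_proj.weight"))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(route("model.layers.0.mlp.up_proj.weight"))
# -> ('model.layers.0.mlp.gate_up_proj.weight', 1)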
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import DequantSiluAndMulQuant +from vllm.model_executor.layers.attention import DequantPagedAttention +from vllm.model_executor.layers.layernorm import (RMSNorm, + RMSNormQuant, + AddResidualRMSNormQuant, + DequantAddResidualRMSNormQuant) + +from vllm.model_executor.layers.quantization.smoothquant import SmoothLinearMethod + +from vllm.model_executor.layers.linear import (LinearMethodBase, + QuantMergedColumnParallelLinear, + QuantQKVParallelLinear, + QuantRowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_dequant_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, + ParallelLMHead) +from vllm.model_executor.layers.layernorm import DequantAddResidual, AddResidual +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class QuantLlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = QuantMergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method, + skip_bias_add=True) + self.down_proj = QuantRowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method, + skip_bias_add=True) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = DequantSiluAndMulQuant() + + def forward(self, x): + scale = None + # int, half -> int32 + gate_up, _ = self.gate_up_proj(x) + # int32 -> int, scale + x, *scale = self.act_fn(gate_up) + scale = scale[0] if scale is not None else None + # int8, scale -> int32(when tp > 1, to half, scale for dequant before all reduce) + x, _ = self.down_proj(x, scale) + return x, scale + + +class QuantLlamaAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, + + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QuantQKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + skip_bias_add=True, + ) + self.o_proj = QuantRowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + skip_bias_add=True, + ) + + self.rotary_emb = get_dequant_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = DequantPagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata + ) -> torch.Tensor: + # int8 -> int32 + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + # int32 -> half + q, k, v = self.rotary_emb(positions, q, k, v, + self.qkv_proj.q_dequant_scale.item(), + self.qkv_proj.k_dequant_scale.item(), + self.qkv_proj.v_dequant_scale.item()) + k_cache, v_cache = kv_cache + scale = None + # half - > int8, scale, 添加一个per channel 量化,并返回统计的scale + attn_output, *scale = self.attn(q, k, v, k_cache, v_cache, input_metadata) + scale = scale[0] if scale is not None else None + # int8, scale -> int32(when tp > 1, to half, scale for dequant before all reduce) + output, _ = self.o_proj(attn_output, scale) + return output, scale + + +class QuantLlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + self.self_attn = QuantLlamaAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + linear_method=linear_method, + ) + self.mlp = QuantLlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.apply_dequant_in_post = not linear_method.apply_dequant_after_row + self.input_layernorm = RMSNormQuant(config.hidden_size, + eps=config.rms_norm_eps) + if self.apply_dequant_in_post: + self.post_attention_layernorm = DequantAddResidualRMSNormQuant(config.hidden_size, + eps=config.rms_norm_eps) + self.finally_add_residual = DequantAddResidual() + else: + self.post_attention_layernorm = AddResidualRMSNormQuant(config.hidden_size, + eps=config.rms_norm_eps) + self.finally_add_residual = AddResidual() + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata + ) -> Tuple[torch.Tensor, torch.Tensor]: + # half + residual = hidden_states + # half -> int8 + hidden_states = 
self.input_layernorm(hidden_states) + # int8 -> int32 ,scale (when tp > 1,to half, scale, this scale is useless) + hidden_states, scale = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata + ) + + # to = 1: int32, half, scale -> int8, half (scale for dequant) + # tp > 1: half, half, scale -> int8, half + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual, scale) + # int8 -> int32, scale (when tp > 1,to half, scale, this scale is useless) + hidden_states, scale = self.mlp(hidden_states) + # ine32, half, scale -> half (when tp > 1, half, half, scale -> half) + hidden_states = self.finally_add_residual(hidden_states, residual, scale) + return hidden_states + + +class QuantLlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + QuantLlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata + ) -> torch.Tensor: + # half + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata + ) + # int32 , half, scale -> int8 + hidden_states = self.norm(hidden_states) + return hidden_states + + +class LlamaForCausalLM(nn.Module): + + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = QuantLlamaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> SamplerOutput: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # process special params first + ("qkv_proj.q_dequant_scale", "q_proj.dequant_scale", "-1"), + ("qkv_proj.k_dequant_scale", "k_proj.dequant_scale", "-1"), + ("qkv_proj.v_dequant_scale", "v_proj.dequant_scale", "-1"), + ("act_fn.gate_dequant_scale", "gate_proj.dequant_scale", "-1"), + ("act_fn.up_dequant_scale", "up_proj.dequant_scale", "-1"), + + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + special_params_mapping = [ + ("post_attention_layernorm.dequant_scale", "self_attn.o_proj.dequant_scale"), + 
("finally_add_residual.dequant_scale","mlp.down_proj.dequant_scale") + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + if 'bias' in name: + continue + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader is default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight,shard_id) + break + else: + for (param_name, weight_name) in special_params_mapping: + if weight_name not in name: + continue + # used in o_prof and down_proj when world_size > 1 + if get_tensor_model_parallel_world_size() > 1: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader is default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight,shard_id) + else: + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + if weight_loader is default_weight_loader: + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight,shard_id) + break + else: + if 'bias' not in name: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py new file mode 100644 index 0000000..0100624 --- /dev/null +++ b/vllm/model_executor/models/mixtral.py @@ -0,0 +1,454 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only Mixtral model.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import MixtralConfig + +from vllm.config import LoRAConfig +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MixtralMoE(nn.Module): + """A tensor-parallel MoE implementation for Mixtral that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.tp_size = tp_size or get_tensor_model_parallel_world_size() + self.num_total_experts = num_experts + self.top_k = top_k + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size // self.tp_size + + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + self.gate = ReplicatedLinear(self.hidden_size, + self.num_total_experts, + bias=False, + params_dtype=self.params_dtype, + linear_method=None) + + self.ws = nn.Parameter( + torch.empty(self.num_total_experts, + 2 * self.intermediate_size, + self.hidden_size, + device="cuda", + dtype=self.params_dtype)) + self.w2s = nn.Parameter( + torch.empty(self.num_total_experts, + self.hidden_size, + self.intermediate_size, + device="cuda", + dtype=self.params_dtype)) + + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + weight_name: str, expert_id: int): + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + shard_size = self.intermediate_size + shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size) + if weight_name.endswith("w1.weight"): + param_data[expert_id, 0:shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w3.weight"): + param_data[expert_id, + shard_size:2 * shard_size, :] = loaded_weight[shard, :] + if weight_name.endswith("w2.weight"): + param_data[expert_id, :, :] = loaded_weight[:, shard] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, 
hidden_size = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = fused_moe(hidden_states, + self.ws, + self.w2s, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + + if self.tp_size > 1: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) + + return final_hidden_states.view(batch_size, sequence_length, + hidden_size) + + +class MixtralAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + linear_method=linear_method) + self.block_sparse_moe = MixtralMoE( + 
num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + MixtralDecoderLayer(config, linear_method=linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = MixtralModel(config, + linear_method, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + 
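    # Worked example (illustrative sizes, not read from any checkpoint) of the
    # slicing done by MixtralMoE.weight_loader defined earlier in this file:
    # each expert's w1/w3 are sharded along the intermediate dimension and
    # stacked into ws, and w2 is sharded along its input dimension into w2s.
    # With a full intermediate_size of 14336 and tp_size=4, shard_size =
    # 14336 // 4 = 3584, so every rank stores
    #   ws  : (num_experts, 2 * 3584, hidden_size)  # w1 in rows 0:3584, w3 in rows 3584:7168
    #   w2s : (num_experts, hidden_size, 3584)
    # and rank r copies the slice(r * 3584, (r + 1) * 3584) rows (w1/w3) or
    # columns (w2) of the full checkpoint tensors.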
def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = [ + # (param_name, weight_name, expert_id) + ("ws" if weight_name in ["w1", "w3"] else "w2s", + f"experts.{expert_id}.{weight_name}.weight", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for param_name, weight_name, expert_id in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + weight_name, + expert_id=expert_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py new file mode 100644 index 0000000..a8dadce --- /dev/null +++ b/vllm/model_executor/models/mixtral_quant.py @@ -0,0 +1,412 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Mixtral model.""" +from typing import List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn.functional as F + +from torch import nn +from transformers import MixtralConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + ReplicatedLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class MixtralMLP(nn.Module): + + def __init__( + self, + num_experts: int, + hidden_size: int, + intermediate_size: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.num_experts = num_experts + self.ffn_dim = intermediate_size + self.hidden_dim = hidden_size + + self.w1 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + self.w2 = ReplicatedLinear(self.ffn_dim, + self.hidden_dim, + bias=False, + linear_method=linear_method) + self.w3 = ReplicatedLinear(self.hidden_dim, + self.ffn_dim, + bias=False, + linear_method=linear_method) + + # TODO: Use vllm's SiluAndMul + self.act_fn = nn.SiLU() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + w1_out, _ = self.w1(hidden_states) + w1_out = self.act_fn(w1_out) + w3_out, _ = self.w3(hidden_states) + current_hidden_states = w1_out * w3_out + current_hidden_states, _ = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralMoE(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + self.num_total_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + if self.tp_size > self.num_total_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {self.num_total_experts}.") + # Split experts equally between ranks + self.expert_indicies = np.array_split(range( + self.num_total_experts), self.tp_size)[self.rank].tolist() + if not self.expert_indicies: + raise ValueError( + f"Rank {self.rank} has no experts assigned to it.") + + self.experts = nn.ModuleList([ + MixtralMLP(self.num_total_experts, + config.hidden_size, + config.intermediate_size, + linear_method=linear_method) + if idx in self.expert_indicies else None + for idx in range(self.num_total_experts) + ]) + self.gate = ReplicatedLinear(config.hidden_size, + self.num_total_experts, + bias=False, + linear_method=None) + + def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits, _ = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, + self.top_k, + dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + final_hidden_states = None + for expert_idx in self.expert_indicies: + expert_layer = self.experts[expert_idx] + expert_mask = (selected_experts == expert_idx) + expert_weights = (routing_weights * expert_mask).sum(dim=-1, + keepdim=True) + + current_hidden_states = expert_layer(hidden_states).mul_( + expert_weights) + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states.add_(current_hidden_states) + + return tensor_model_parallel_all_reduce(final_hidden_states).view( + batch_size, sequence_length, hidden_dim) + + +class MixtralAttention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
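            # Illustrative numbers: with total_num_kv_heads=8 and tp_size=2,
            # each rank keeps 8 // 2 = 4 KV heads; with total_num_kv_heads=2
            # and tp_size=8, each KV head is shared by 8 // 2 = 4 ranks and
            # num_kv_heads below becomes max(1, 2 // 8) = 1 per rank.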
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class MixtralDecoderLayer(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = MixtralAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + sliding_window=config.sliding_window, + linear_method=linear_method) + self.block_sparse_moe = MixtralMoE(config=config, + linear_method=linear_method) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.block_sparse_moe(hidden_states) + return hidden_states, residual + + +class MixtralModel(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + MixtralDecoderLayer(config, linear_method=linear_method) + for _ in range(config.num_hidden_layers) + ]) + 
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], input_metadata, + residual) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class MixtralForCausalLM(nn.Module): + + def __init__( + self, + config: MixtralConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = MixtralModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, + cache_dir, + load_format, + revision, + fall_back_to_pt=False): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Skip experts that are not assigned to this worker. + if ("block_sparse_moe.experts." 
in name + and name not in params_dict): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py new file mode 100644 index 0000000..22a876e --- /dev/null +++ b/vllm/model_executor/models/mpt.py @@ -0,0 +1,298 @@ +# coding=utf-8 +# Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput +from vllm.transformers_utils.configs.mpt import MPTConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +def _get_alibi_slopes( + total_num_heads: int, + alibi_bias_max: int, +) -> torch.Tensor: + next_power_of_2 = 2**math.ceil(math.log2(total_num_heads)) + m = torch.arange(1, next_power_of_2 + 1, dtype=torch.float32) + m = m.mul(alibi_bias_max / next_power_of_2) + slopes = 1.0 / torch.pow(2, m) + if next_power_of_2 != total_num_heads: + slopes = torch.concat([slopes[1::2], slopes[::2]])[:total_num_heads] + return slopes + + +class MPTAttention(nn.Module): + + def __init__( + self, + config: MPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.d_model = config.d_model + self.total_num_heads = config.n_heads + self.head_dim = self.d_model // self.total_num_heads + self.clip_qkv = config.attn_config["clip_qkv"] + self.qk_ln = config.attn_config["qk_ln"] + self.alibi_bias_max = config.attn_config["alibi_bias_max"] + if "kv_n_heads" in config.attn_config: + self.total_num_kv_heads = config.attn_config['kv_n_heads'] + else: + self.total_num_kv_heads = self.total_num_heads + assert not config.attn_config["prefix_lm"] + assert config.attn_config["alibi"] + + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( + self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, + self.total_num_kv_heads, + bias=not config.no_bias, + linear_method=linear_method, + ) + if self.qk_ln: + self.q_ln = nn.LayerNorm(self.d_model) + self.k_ln = nn.LayerNorm(self.d_model) + self.out_proj = RowParallelLinear( + self.d_model, + self.d_model, + bias=not config.no_bias, + linear_method=linear_method, + ) + + tp_world_size = get_tensor_model_parallel_world_size() + assert self.total_num_heads % tp_world_size == 0 + self.num_heads = self.total_num_heads // tp_world_size + + if self.total_num_kv_heads >= tp_world_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_world_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_world_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + # Create the alibi slopes and slice them. + tp_rank = get_tensor_model_parallel_rank() + head_start = tp_rank * self.num_heads + head_end = (tp_rank + 1) * self.num_heads + alibi_slopes = _get_alibi_slopes(self.total_num_heads, + self.alibi_bias_max) + alibi_slopes = alibi_slopes[head_start:head_end].tolist() + + self.head_dim = self.d_model // self.total_num_heads + scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scaling, + alibi_slopes=alibi_slopes, + num_kv_heads=self.num_kv_heads) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + del position_ids # unused. + qkv, _ = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.qk_ln: + q = self.q_ln(q) + k = self.k_ln(k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class MPTMLP(nn.Module): + + def __init__( + self, + config: MPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.d_model + expansion_ratio = config.expansion_ratio + intermediate_size = expansion_ratio * hidden_size + self.up_proj = ColumnParallelLinear( + hidden_size, + intermediate_size, + bias=not config.no_bias, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn("gelu", quant_config, intermediate_size) + self.down_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=not config.no_bias, + linear_method=linear_method, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.up_proj(x) + x = self.act(x) + x, _ = self.down_proj(x) + return x + + +class MPTBlock(nn.Module): + + def __init__( + self, + config: MPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + hidden_size = config.d_model + self.norm_1 = nn.LayerNorm(hidden_size) + self.attn = MPTAttention(config, linear_method) + self.norm_2 = nn.LayerNorm(hidden_size) + self.ffn = MPTMLP(config, linear_method) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + x = self.norm_1(hidden_states) + x = self.attn( + position_ids=position_ids, + hidden_states=x, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + hidden_states = hidden_states + x + x = self.norm_2(hidden_states) + x = self.ffn(x) + hidden_states = hidden_states + x + return hidden_states + + +class MPTModel(nn.Module): + + def __init__( + self, + config: MPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + assert config.embedding_fraction == 1.0 + assert config.norm_type == "low_precision_layernorm" + + self.wte = VocabParallelEmbedding( + config.vocab_size, + config.d_model, + ) + self.blocks = nn.ModuleList( + 
[MPTBlock(config, linear_method) for _ in range(config.n_layers)]) + self.norm_f = nn.LayerNorm(config.d_model) + if config.no_bias: + for module in self.modules(): + if hasattr(module, "bias") and isinstance( + module.bias, nn.Parameter): + # Remove the bias term in Linear and LayerNorm. + module.register_parameter("bias", None) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + for i in range(len(self.blocks)): + block = self.blocks[i] + hidden_states = block( + position_ids, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.norm_f(hidden_states) + return hidden_states + + +class MPTForCausalLM(nn.Module): + + def __init__( + self, + config: MPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + assert config.tie_word_embeddings + self.linear_method = linear_method + + self.transformer = MPTModel(config, linear_method) + self.lm_head_weight = self.transformer.wte.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py new file mode 100644 index 0000000..9d56303 --- /dev/null +++ b/vllm/model_executor/models/olmo.py @@ -0,0 +1,380 @@ +# coding=utf-8 +# Adapted from +# https://github.com/allenai/OLMo/blob/v0.2.4/olmo/model.py and +# https://github.com/allenai/OLMo/blob/v0.2.4/hf_olmo/modeling_olmo.py +# Copyright 2023 The vLLM team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Inference-only OLMo model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, ) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) +from vllm.sequence import SamplerOutput + +# this model must need this dependency +from hf_olmo import OLMoConfig + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class SwiGLU(nn.Module): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gate = x.chunk(2, dim=-1) + return F.silu(gate) * x + + @property + def output_multiplier(self) -> float: + return 0.5 + + +class OlmoAttention(nn.Module): + """ + This is the attention block where the output is computed as ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.hidden_size = config.d_model + assert config.d_model % config.n_heads == 0 + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = self.config.n_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = self.total_num_heads // tensor_model_parallel_world_size + self.head_dim = self.hidden_size // self.total_num_heads + + # Layer norms. + self.attn_norm = nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False) + # Attention input projection. Projects x -> (q, k, v) + self.att_proj = QKVParallelLinear( + config.d_model, + self.head_dim, + self.total_num_heads, + bias=config.include_bias, + linear_method=linear_method, + ) + + # Rotary embeddings. 
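        # Note: when config.rope is False, no rotary_emb attribute is created
        # and forward() below skips the rotation, so `positions` has no effect
        # in this attention block.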
+ if self.config.rope: + rope_theta = getattr(config, "rope_theta", 10000) + max_position_embeddings = getattr(config, + "max_position_embeddings", 8192) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.scaling = self.head_dim**-0.5 + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scaling) + + # Attention output projection. + self.attn_out = RowParallelLinear( + config.d_model, + config.d_model, + bias=config.include_bias, + linear_method=linear_method, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.attn_norm(hidden_states) + qkv, _ = self.att_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + if self.config.rope: + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.attn_out(attn_output) + return output + + +class OlmoMLP(nn.Module): + """ + This is the MLP block where the output is computed as ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__( + self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.hidden_size = (config.mlp_hidden_size if config.mlp_hidden_size + is not None else config.mlp_ratio * config.d_model) + + # Layer norms. + self.ff_norm = nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False) + + # Feed-forward input projection. + self.ff_proj = ColumnParallelLinear( + config.d_model, + self.hidden_size, + bias=config.include_bias, + linear_method=linear_method, + ) + + # Activation function. + # self.act = SiluAndMul() + # self.act.output_multiplier = 0.5 + self.act = SwiGLU() + assert (self.act.output_multiplier * self.hidden_size) % 1 == 0 + + # Feed-forward output projection. + self.ff_out = RowParallelLinear( + int(self.act.output_multiplier * self.hidden_size), + config.d_model, + bias=config.include_bias, + linear_method=linear_method, + ) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + # Add feed-forward projection. + # shape: (batch_size, seq_len, d_model) + og_x = x + x = self.ff_norm(x) + x, _ = self.ff_proj(x) + x = self.act(x) + x, _ = self.ff_out(x) + x = og_x + x + + return x + + +class OlmoBlock(nn.Module): + """ + This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))`` + (plus another skip connection). + """ + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + # Attention block. + self.attn = OlmoAttention(config, linear_method) + + # MLP block. + self.mlp = OlmoMLP(config, linear_method) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + # Attention block. + og_x = hidden_states + x = self.attn(positions, hidden_states, kv_cache, input_metadata) + x = x + og_x + + # MLP block. 
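        # OlmoMLP.forward() above already applies its own skip connection
        # (og_x + x), so no extra residual add is needed around self.mlp here.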
+ hidden_states = self.mlp(x) + return hidden_states + + +class OlmoModel(nn.Module): + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + + self.transformer = nn.ModuleDict( + dict( + wte=VocabParallelEmbedding( + config.embedding_size or config.vocab_size, + config.d_model, + ), + ln_f=nn.LayerNorm(config.d_model, + elementwise_affine=False, + bias=False), + )) + + blocks = [ + OlmoBlock(config, linear_method) for i in range(config.n_layers) + ] + if self.config.block_group_size > 1: + raise NotImplementedError("Block group size > 1 not supported yet") + else: + self.transformer.update({"blocks": nn.ModuleList(blocks)}) + + if not config.weight_tying: + self.transformer.update({ + "ff_out": + ColumnParallelLinear( + config.d_model, + config.embedding_size or config.vocab_size, + bias=config.include_bias, + linear_method=linear_method, + ) + }) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + """ + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.wte(input_ids) # type: ignore + + # Apply blocks one-by-one. + for block_idx, block in enumerate(self.transformer.blocks): + # shape: (batch_size, seq_len, d_model) + x = block( + positions, + x, + kv_caches[block_idx], + input_metadata, + ) + + # Apply final layer norm. + # shape: (batch_size, seq_len or 1, d_model) + x = self.transformer.ln_f(x) # type: ignore + return x + + +class OLMoForCausalLM(nn.Module): + """ + Extremely barebones HF model wrapper. + """ + + def __init__(self, + config: OLMoConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = OlmoModel(config, linear_method) + self.lm_head_weight = (self.model.transformer.wte.weight + if config.weight_tying else + self.model.transformer.ff_out.weight) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + # attention + if ".att" in name: + name = name.replace(".att", ".attn.att") + # mlp + if ".ff" in name and "transformer.ff_out" not in name: + name = name.replace(".ff", ".mlp.ff") + # there is no bias in olmo + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py new file mode 100644 index 0000000..393b2dc --- /dev/null +++ 
b/vllm/model_executor/models/opt.py @@ -0,0 +1,354 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py +# Copyright 2023 The vLLM team. +# Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights +# reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only OPT model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import OPTConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class OPTLearnedPositionalEmbedding(nn.Embedding): + + def __init__(self, num_embeddings: int, embedding_dim: int): + # OPT is set up so that if padding_idx is specified then offset the + # embedding ids by 2 and adjust num_embeddings appropriately. 
Other + # models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, positions: torch.Tensor): + return super().forward(positions + self.offset) + + +class OPTAttention(nn.Module): + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.embed_dim = embed_dim + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + total_num_heads = num_heads + assert num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = total_num_heads // tensor_model_parallel_world_size + self.head_dim = embed_dim // total_num_heads + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + embed_dim, + self.head_dim, + total_num_heads, + bias=bias, + linear_method=linear_method, + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + bias=bias, + linear_method=linear_method, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + scale=self.scaling) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + key_cache, value_cache = kv_cache + attn_output = self.attn(q, k, v, key_cache, value_cache, + input_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class OPTDecoderLayer(nn.Module): + + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.self_attn = OPTAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + bias=config.enable_bias, + linear_method=linear_method, + ) + self.do_layer_norm_before = config.do_layer_norm_before + + self.self_attn_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + self.fc1 = ColumnParallelLinear( + self.embed_dim, + config.ffn_dim, + bias=config.enable_bias, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.activation_fn = get_act_fn(config.activation_function, + quant_config, config.ffn_dim) + self.fc2 = RowParallelLinear( + config.ffn_dim, + self.embed_dim, + bias=config.enable_bias, + linear_method=linear_method, + ) + self.final_layer_norm = nn.LayerNorm( + self.embed_dim, + elementwise_affine=config.layer_norm_elementwise_affine) + + def forward( + self, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn(hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Fully Connected + residual = hidden_states + # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention + if self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = 
self.fc2(hidden_states) + hidden_states = residual + hidden_states + # 350m applies layer norm AFTER attention + if not self.do_layer_norm_before: + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + +class OPTDecoder(nn.Module): + + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.max_target_positions = config.max_position_embeddings + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.word_embed_proj_dim, + ) + # Positional embeddings are replicated (not sharded). + self.embed_positions = OPTLearnedPositionalEmbedding( + config.max_position_embeddings, config.hidden_size) + + # Project out & in will be replicated if they exist. + if config.word_embed_proj_dim != config.hidden_size: + self.project_out = ReplicatedLinear(config.hidden_size, + config.word_embed_proj_dim, + bias=False, + linear_method=linear_method) + else: + self.project_out = None + + if config.word_embed_proj_dim != config.hidden_size: + self.project_in = ReplicatedLinear(config.word_embed_proj_dim, + config.hidden_size, + bias=False, + linear_method=linear_method) + else: + self.project_in = None + + # Note that the only purpose of `config._remove_final_layer_norm` is to + # keep backward compatibility with checkpoints that have been fine-tuned + # before transformers v4.20.1 + # see https://github.com/facebookresearch/metaseq/pull/164 + if config.do_layer_norm_before and not config._remove_final_layer_norm: + self.final_layer_norm = nn.LayerNorm( + config.hidden_size, + elementwise_affine=config.layer_norm_elementwise_affine) + else: + self.final_layer_norm = None + + self.layers = nn.ModuleList([ + OPTDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + inputs_embeds = self.embed_tokens(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + inputs_embeds, _ = self.project_in(inputs_embeds) + hidden_states = inputs_embeds + pos_embeds + + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(hidden_states, kv_caches[i], input_metadata) + + if self.final_layer_norm is not None: + hidden_states = self.final_layer_norm(hidden_states) + if self.project_out is not None: + hidden_states, _ = self.project_out(hidden_states) + return hidden_states + + +class OPTModel(nn.Module): + + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.decoder = OPTDecoder(config, linear_method) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + return self.decoder(input_ids, positions, kv_caches, input_metadata) + + +class OPTForCausalLM(nn.Module): + + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = OPTModel(config, linear_method) + self.lm_head_weight = self.model.decoder.embed_tokens.weight + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: 
List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "lm_head.weight" in name: + continue + if name.startswith("decoder."): + name = "model." + name + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py new file mode 100644 index 0000000..d143261 --- /dev/null +++ b/vllm/model_executor/models/phi.py @@ -0,0 +1,305 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/microsoft/phi-1_5/blob/main/modeling_phi.py +# Copyright 2023 The vLLM team. +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. +# +# BSD 3-Clause License +# +# Copyright (c) 2022, Tri Dao, trid@cs.stanford.edu. +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +"""Inference-only Phi-1.5 model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class PhiAttention(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.total_num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_size = self.hidden_size // self.total_num_heads + + tensor_model_parallel_world_size = ( + get_tensor_model_parallel_world_size()) + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + + # pylint: disable=C0103 + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_size, + self.total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.dense = RowParallelLinear( + self.hidden_size, + self.hidden_size, + linear_method=linear_method, + ) + + scaling = self.head_size**-0.5 + rotary_dim = int(config.partial_rotary_factor * + (config.hidden_size // config.num_attention_heads)) + assert rotary_dim % 2 == 0 + + # pylint: disable=C0301 + # Refer to: + # https://huggingface.co/microsoft/phi-1_5/blob/d212a789620c380ff32ca1d1ee9943a777360987/modeling_phi.py#L518 + rope_theta = 10000 + max_position_embeddings = getattr(config, "n_positions", 2048) + self.rotary_emb = get_rope( + self.head_size, + rotary_dim=rotary_dim, + max_position=max_position_embeddings, + base=rope_theta, + ) + self.attn = PagedAttention(self.num_heads, self.head_size, scaling) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(position_ids, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.dense(attn_output) + return output + + +class 
PhiMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + + n_inner = getattr(config, "n_inner", None) + n_inner = n_inner if n_inner is not None else 4 * config.hidden_size + + self.fc1 = ColumnParallelLinear( + config.hidden_size, + n_inner, + linear_method=linear_method, + ) + self.fc2 = RowParallelLinear( + n_inner, + config.hidden_size, + linear_method=linear_method, + ) + quant_config = getattr(linear_method, "quant_config", None) + self.act = get_act_fn(config.hidden_act, quant_config, n_inner) + + def forward(self, hidden_states): + hidden_states, _ = self.fc1(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.fc2(hidden_states) + return hidden_states + + +class PhiLayer(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + self.self_attn = PhiAttention(config, linear_method) + self.mlp = PhiMLP(config, linear_method) + + def forward( + self, + position_ids: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_outputs = self.self_attn( + position_ids=position_ids, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_outputs + feed_forward_hidden_states + residual + return hidden_states + + +class PhiModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.linear_method = linear_method + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + PhiLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.final_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(self.config.num_hidden_layers): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + ) + + hidden_states = self.final_layernorm(hidden_states) + + return hidden_states + + +class PhiForCausalLM(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.linear_method = linear_method + + self.model = PhiModel(config, linear_method) + + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + bias=True) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + head = self.lm_head + next_tokens = self.sampler(head.weight, hidden_states, + sampling_metadata, head.bias) + return 
next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v") + ] + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # pylint: disable=E1136 + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py new file mode 100644 index 0000000..37af84c --- /dev/null +++ b/vllm/model_executor/models/qwen.py @@ -0,0 +1,288 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# Copyright (c) Alibaba Cloud. +# LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE +"""Inference-only QWen model compatible with HuggingFace weights.""" +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class QWenMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str = "silu", + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.c_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. 
" + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class QWenAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.hidden_size = hidden_size + tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( + ) + self.total_num_heads = num_heads + assert self.total_num_heads % tensor_model_parallel_world_size == 0 + self.num_heads = (self.total_num_heads // + tensor_model_parallel_world_size) + self.head_dim = hidden_size // self.total_num_heads + self.c_attn = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + bias=True, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + self.scaling = self.head_dim**-0.5 + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.chunk(chunks=3, dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + + output, _ = self.c_proj(attn_output) + return output + + +class QWenBlock(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + self.attn = QWenAttention(config.hidden_size, + config.num_attention_heads, + config.max_position_embeddings, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + linear_method=linear_method) + + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + self.mlp = QWenMLP(config.hidden_size, + config.intermediate_size // 2, + linear_method=linear_method) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class QWenModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.vocab_size = config.vocab_size + + self.wte = VocabParallelEmbedding( + config.vocab_size, + 
config.hidden_size, + ) + self.h = nn.ModuleList([ + QWenBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.wte(input_ids) + residual = None + for i in range(len(self.h)): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class QWenLMHeadModel(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.config = config + self.linear_method = linear_method + self.transformer = QWenModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.transformer(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w2", 0), + ("gate_up_proj", "w1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py new file mode 100644 index 0000000..4d0e822 --- /dev/null +++ b/vllm/model_executor/models/qwen2.py @@ -0,0 +1,340 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. 
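The load_weights method of QWenLMHeadModel above uses stacked_params_mapping to route the separate w1/w2 checkpoint shards into the fused gate_up_proj parameter. A minimal sketch of that name-resolution step (the checkpoint tensor name below is hypothetical, for illustration only):

    stacked_params_mapping = [
        ("gate_up_proj", "w2", 0),  # checkpoint shard "w2" -> fused shard 0
        ("gate_up_proj", "w1", 1),  # checkpoint shard "w1" -> fused shard 1
    ]

    def resolve(name):
        # Map a checkpoint tensor name to (fused param name, shard id);
        # names matching no entry are loaded as-is with the default loader.
        for param_name, weight_name, shard_id in stacked_params_mapping:
            if weight_name in name:
                return name.replace(weight_name, param_name), shard_id
        return name, None

    # resolve("transformer.h.0.mlp.w1.weight")
    # -> ("transformer.h.0.mlp.gate_up_proj.weight", 1)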
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + use_sliding_window: bool = False, + linear_method: Optional[LinearMethodBase] = None, + sliding_window: Optional[int] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
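+            # Worked example (illustration): with total_num_kv_heads = 4 and
+            # tp_size = 8, the assert below requires 8 % 4 == 0, each KV head
+            # is shared by 8 // 4 = 2 ranks, and num_kv_heads becomes
+            # max(1, 4 // 8) = 1 per rank. With tp_size = 2 the branch above
+            # applies instead and each rank owns 4 // 2 = 2 distinct KV heads.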
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.sliding_window = sliding_window if use_sliding_window else None + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + linear_method=linear_method, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + layer_idx: int, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + use_sliding_window = config.use_sliding_window and layer_idx < config.max_window_layers + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + use_sliding_window=use_sliding_window, + linear_method=linear_method, + sliding_window=config.sliding_window) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + linear_method=linear_method, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2Model(nn.Module): + + def __init__( + self, + config: Qwen2Config, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = 
config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + Qwen2DecoderLayer(config, layer_idx, linear_method) + for layer_idx in range(config.num_hidden_layers) + ]) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + residual, + ) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen2ForCausalLM(nn.Module): + + def __init__( + self, + config: Qwen2Config, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = Qwen2Model(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + try: + param = params_dict[name] + except: + assert name=="lm_head.weight" # for qwen1.5 0.5b,skip this + continue + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py new file mode 100644 index 0000000..44c57e5 --- /dev/null +++ b/vllm/model_executor/models/stablelm.py @@ -0,0 +1,303 @@ +# coding=utf-8 +# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
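In Qwen2Attention.forward above, the fused QKV output is split by q_size and kv_size rather than into three equal chunks, because grouped-query attention gives K and V fewer heads than Q. A small numeric sketch with hypothetical sizes (tp_size = 1):

    import torch

    num_heads, num_kv_heads, head_dim = 16, 4, 64
    q_size = num_heads * head_dim      # 1024
    kv_size = num_kv_heads * head_dim  # 256

    qkv = torch.randn(2, q_size + 2 * kv_size)  # fused projection output
    q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
    print(q.shape, k.shape, v.shape)  # [2, 1024], [2, 256], [2, 256]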
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# This code is based off the following work: +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/modeling_stablelm_epoch.py +# https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json +"""Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class StablelmMLP(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = MergedColumnParallelLinear( + config.hidden_size, [config.intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=False) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class StablelmAttention(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + self.num_heads = self.total_num_heads // tp_size + + self.total_num_key_value_heads = config.num_key_value_heads + if self.total_num_key_value_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_key_value_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_key_value_heads == 0 + self.num_key_value_heads = max( + 1, self.total_num_key_value_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.max_position_embeddings = config.max_position_embeddings + rope_pct = getattr(config, "rope_pct", + getattr(config, "partial_rotary_factor", 1)) + self.rotary_ndims = int(self.head_dim * rope_pct) + self.scaling = self.head_dim**-0.5 + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_key_value_heads * self.head_dim + self.qkv_bias = getattr(config, "use_qkv_bias", False) + if (self.head_dim * self.num_heads * tp_size) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = QKVParallelLinear(self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_key_value_heads, + self.qkv_bias, + linear_method=linear_method) + self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, + self.hidden_size, + bias=False, + linear_method=linear_method) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.rotary_ndims, + max_position=self.config.max_position_embeddings, + base=self.config.rope_theta, + ) + self.attn = PagedAttention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_key_value_heads) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class StablelmDecoderLayer(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.self_attn = StablelmAttention(config) + self.mlp = StablelmMLP(config, linear_method) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states, residual + + +class StableLMEpochModel(nn.Module): + + def __init__(self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None) -> None: + super().__init__() + # self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, config.pad_token_id) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + ) + self.layers = nn.ModuleList([ + StablelmDecoderLayer(config, linear_method) + for _ 
in range(config.num_hidden_layers) + ]) + norm_eps = getattr(config, "norm_eps", + getattr(config, "layer_norm_eps", 1e-05)) + self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i], + input_metadata, + ) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class StablelmForCausalLM(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + linear_method: Optional[LinearMethodBase] = None, + ) -> None: + super().__init__() + self.config = config + self.linear_method = linear_method + self.model = StableLMEpochModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) + self.sampler = Sampler(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head.weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py new file mode 100644 index 0000000..1eda07b --- /dev/null +++ b/vllm/model_executor/models/starcoder2.py @@ -0,0 +1,310 @@ +# coding=utf-8 +# Copyright 2024 BigCode and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. 
It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Starcoder2 model.""" +from typing import List, Optional, Tuple + +import torch +from torch import nn + +from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) +from vllm.model_executor.parallel_utils.parallel_state import get_tensor_model_parallel_world_size +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput + +try: + from transformers import Starcoder2Config +except ImportError: + # fallback to PretrainedConfig + # NOTE: Please install transformers from source or use transformers>=4.39.0 + from transformers import PretrainedConfig as Starcoder2Config + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class Starcoder2Attention(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + + self.hidden_size = config.hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = config.num_attention_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = config.num_key_value_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = self.hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = config.rope_theta + self.max_position_embeddings = config.max_position_embeddings + self.use_bias = config.use_bias + self.sliding_window = config.sliding_window + + self.qkv_proj = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=self.use_bias, + linear_method=linear_method, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + self.hidden_size, + bias=self.use_bias, + linear_method=linear_method, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = PagedAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + sliding_window=self.sliding_window, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + k_cache, v_cache = kv_cache + attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class Starcoder2MLP(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.c_fc = ColumnParallelLinear( + config.hidden_size, + config.intermediate_size, + bias=config.use_bias, + linear_method=linear_method, + ) + self.c_proj = RowParallelLinear( + config.intermediate_size, + config.hidden_size, + bias=config.use_bias, + linear_method=linear_method, + ) + self.act = get_act_fn(config.hidden_act, + intermediate_size=config.intermediate_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.c_proj(hidden_states) + return hidden_states + + +class Starcoder2DecoderLayer(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = Starcoder2Attention(config, + linear_method=linear_method) + self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: KVCache, + input_metadata: InputMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + input_metadata=input_metadata, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + 
hidden_states + + return hidden_states + + +class Starcoder2Model(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + # TODO: consider padding_idx (currently removed) + self.embed_tokens = VocabParallelEmbedding(config.vocab_size, + config.hidden_size) + self.layers = nn.ModuleList([ + Starcoder2DecoderLayer(config, linear_method=linear_method) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, kv_caches[i], + input_metadata) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class Starcoder2ForCausalLM(nn.Module): + + def __init__(self, + config: Starcoder2Config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__() + self.config = config + self.model = Starcoder2Model(config, linear_method=linear_method) + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + if config.tie_word_embeddings: + self.lm_head_weight = self.model.embed_tokens.weight + else: + self.unpadded_vocab_size = config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) + self.lm_head_weight = self.lm_head.weight + self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + input_metadata) + return hidden_states + + def sample( + self, + hidden_states: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(self.lm_head_weight, hidden_states, + sampling_metadata) + return next_tokens + + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in hf_model_weights_iterator( + model_name_or_path, cache_dir, load_format, revision): + if "rotary_emb.inv_freq" in name: + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/neuron_model_loader.py b/vllm/model_executor/neuron_model_loader.py new file mode 100644 index 
0000000..8d5fd82 --- /dev/null +++ b/vllm/model_executor/neuron_model_loader.py @@ -0,0 +1,66 @@ +"""Utilities for selecting and loading models.""" +from typing import Type + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import ModelConfig, DeviceConfig +from vllm.model_executor.models import ModelRegistry + +TORCH_DTYPE_TO_NEURON_AMP = { + "auto": "f32", + "half": "f16", + "float16": "f16", + "bfloat16": "bf16", + "float": "f32", + "float32": "f32", + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", +} + + +def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: + architectures = getattr(config, "architectures", []) + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return model_cls + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> nn.Module: + from transformers_neuronx.config import NeuronConfig, ContinuousBatchingConfig + + parallel_config = kwargs.get("parallel_config") + scheduler_config = kwargs.get("scheduler_config") + + model_class = _get_model_architecture(model_config.hf_config) + linear_method = None + + # Create a model instance. + model = model_class(model_config.hf_config, linear_method) + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + neuron_config = NeuronConfig( + continuous_batching=continuous_batching_config) + + # Load the weights from the cached or downloaded files. + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + tp_degree=parallel_config.neuron_tp_degree, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=[scheduler_config.max_model_len], + n_positions=[scheduler_config.max_model_len], + batch_size=scheduler_config.max_num_seqs) + + return model.eval() \ No newline at end of file diff --git a/vllm/model_executor/parallel_utils/__init__.py b/vllm/model_executor/parallel_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py new file mode 100644 index 0000000..ce69f91 --- /dev/null +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -0,0 +1,213 @@ +from collections import namedtuple +from typing import Any, Dict, List, Optional, Union + +import torch +from torch.distributed import ProcessGroup + +from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tensor_model_parallel_group, + is_cupy_nccl_enabled_for_all_reduce, +) +from vllm.model_executor.parallel_utils.custom_all_reduce import custom_all_reduce +from ixformer.contrib.torch.extension.ixformer_torch.distributed import ( + create_ixformer_group_from_pg, +) +from ixformer.distributed import all_reduce +_IXFORMER_TENSOR_MODEL_PARALLEL_GROUP = None + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group. 
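+
+    A minimal usage sketch (illustrative); always consume the return value:
+
+        hidden_states = tensor_model_parallel_all_reduce(hidden_states)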
+ + NOTE: This operation will be applied in-place on the input tensor if + disable_custom_all_reduce is set to True. Otherwise, this operation may or + may not be applied in place depending on whether custom all reduce is + invoked for a particular tensor, which further depends on the tensor size + and GPU topology. + + TLDR: always assume this function modifies its input, but use the return + value as the output. + """ + # Bypass the function if we are using only 1 GPU. + if get_tensor_model_parallel_world_size() == 1: + return input_ + global _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP + if _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP is None: + _IXFORMER_TENSOR_MODEL_PARALLEL_GROUP = create_ixformer_group_from_pg(get_tensor_model_parallel_group()) + out = custom_all_reduce(input_) + if out is not None: + return out + if is_cupy_nccl_enabled_for_all_reduce(): + # TODO: support multiple parallel groups. + cupy_utils.all_reduce(input_) + else: + all_reduce(input_,group=_IXFORMER_TENSOR_MODEL_PARALLEL_GROUP,async_op=True) + # TODO use our all reduce.. + # torch.distributed.all_reduce(input_, + # group=get_tensor_model_parallel_group()) + return input_ + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor( + output_tensor, input_, group=get_tensor_model_parallel_group()) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> torch.Tensor: + """Gather the input tensor across model parallel group. + + NOTE: We assume that the input tensor is on the same device across + all the ranks. + """ + world_size = get_tensor_model_parallel_world_size() + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if get_tensor_model_parallel_rank() == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=dst, + group=get_tensor_model_parallel_group()) + if get_tensor_model_parallel_rank() == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + +def broadcast(input_: torch.Tensor, + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input tensor.""" + group = group or torch.distributed.group.WORLD + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. 
+ world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, src=src, group=group) + return input_ + + +def broadcast_object_list(obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list.""" + group = group or torch.distributed.group.WORLD + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, src=src, group=group) + return obj_list + + +TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) + + +def broadcast_tensor_dict( + tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, +) -> Dict[Any, Union[torch.Tensor, Any]]: + """Broadcast the input tensor dictionary.""" + group = group or torch.distributed.group.WORLD + ranks = torch.distributed.get_process_group_ranks(group) + assert src in ranks, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + world_size = torch.distributed.get_world_size(group=group) + if world_size == 1: + return tensor_dict + + rank = torch.distributed.get_rank() + if rank == src: + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + assert value.is_cuda, ( + f"Tensor {key}: {value} is not on cuda. Currently we only " + f"support broadcasting tensors on cuda.") + metadata_list.append( + (key, TensorMetadata(value.dtype, value.size()))) + else: + metadata_list.append((key, value)) + torch.distributed.broadcast_object_list([metadata_list], + src=src, + group=group) + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = tensor_dict[key] + torch.distributed.broadcast(tensor, src=src) + else: + recv_metadata_list = [None] + torch.distributed.broadcast_object_list(recv_metadata_list, + src=src, + group=group) + metadata_list = recv_metadata_list[0] + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device="cuda") + async_handle = torch.distributed.broadcast(tensor, + src=src, + async_op=True, + group=group) + async_handles.append(async_handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict diff --git a/vllm/model_executor/parallel_utils/cupy_utils.py b/vllm/model_executor/parallel_utils/cupy_utils.py new file mode 100644 index 0000000..f8cffc0 --- /dev/null +++ b/vllm/model_executor/parallel_utils/cupy_utils.py @@ -0,0 +1,130 @@ +"""CuPy utilities for all-reduce. + +We use CuPy all-reduce instead of torch.distributed.all_reduce when capturing +CUDA graphs, because torch.distributed.all_reduce causes errors when capturing +CUDA graphs. + +NOTE: We use CuPy 12.3 since CuPy 13.0 does not support Python 3.8. +TODO: Remove this file when torch.distributed.all_reduce is fixed. 
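+
+Typical flow (illustrative sketch):
+
+    init_process_group(world_size, rank, host, port)
+    with set_cupy_stream(torch.cuda.current_stream()):
+        all_reduce(tensor)  # in-place SUM across the process group
+    destroy_process_group()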
+""" +import contextlib + +import torch +from torch.distributed import ReduceOp + +try: + import cupy + from cupy.cuda import nccl + from cupyx.distributed import NCCLBackend +except ImportError as e: + cupy = e + nccl = None + + class NCCLBackend: + ... + + +_OP_MAPPING = { + ReduceOp.SUM: "sum", + ReduceOp.PRODUCT: "prod", + ReduceOp.MIN: "min", + ReduceOp.MAX: "max", +} + + +class NCCLBackendWithBFloat16(NCCLBackend): + # This is enough to add bfloat16 support for most operations, + # but broadcast will fail (will require changes in compiled + # cupy code). + def _get_nccl_dtype_and_count(self, array, count=None): + nccl_dtype, count = super()._get_nccl_dtype_and_count(array, count) + torch_dtype = getattr(array, "_torch_dtype", None) + if torch_dtype is torch.bfloat16: + nccl_dtype = nccl.NCCL_BFLOAT16 + return nccl_dtype, count + + def barrier(self) -> None: + raise RuntimeError( + "Currently, CuPy NCCL barrier is not supported since the TCP " + "store is immediately stopped after the initialization.") + + +_NCCL_BACKEND = None +_WORLD_SIZE = 0 + + +def is_initialized() -> bool: + """Returns whether the NCCL backend is initialized.""" + return _NCCL_BACKEND is not None + + +@contextlib.contextmanager +def set_cupy_stream(stream: torch.cuda.Stream): + """Set the cuda stream for communication""" + cupy_stream = cupy.cuda.ExternalStream(stream.cuda_stream, + stream.device_index) + with cupy_stream: + yield + + +def init_process_group(world_size: int, rank: int, host: str, + port: int) -> None: + """Initializes the CuPy NCCL backend. + + # TODO: handle NCCL timeouts. + """ + assert not is_initialized() + + if isinstance(cupy, Exception): + raise ImportError( + "NCCLBackend is not available. Please install cupy.") from cupy + + # TODO(woosuk): Create TP and PP process groups for CuPy. + global _NCCL_BACKEND + global _WORLD_SIZE + assert world_size > 0, f"{world_size=} should be a positive integer" + assert 0 <= rank < world_size, ( + f"{rank=} should be a integer between [0, {world_size})") + + cupy.cuda.runtime.setDevice(torch.cuda.current_device()) + _NCCL_BACKEND = NCCLBackendWithBFloat16(world_size, rank, host, port) + _WORLD_SIZE = world_size + + # Stop the TCP store to prevent the deadlock issues at termination time. + # FIXME(woosuk): This is hacky. Find a more robust solution. + if rank == 0 and hasattr(_NCCL_BACKEND, "_store"): + _NCCL_BACKEND._store.stop() + + +def all_reduce(input_: torch.Tensor, op=ReduceOp.SUM) -> None: + """All-reduces the input tensor across the process group.""" + assert input_.is_cuda, f"{input_} should be a cuda tensor" + # Hack to support bfloat16 + torch_dtype = input_.dtype + if torch_dtype is torch.bfloat16: + # We need to view as float16, otherwise + # cupy will fail. This will not change + # the underlying data. 
+ input_ = input_.view(torch.float16) + cupy_input = cupy.asarray(input_) + cupy_input._torch_dtype = torch_dtype # pylint: disable=protected-access + _NCCL_BACKEND.all_reduce(in_array=cupy_input, + out_array=cupy_input, + op=_OP_MAPPING[op]) + + +def destroy_process_group() -> None: + """Destroys the NCCL backend.""" + global _NCCL_BACKEND + global _WORLD_SIZE + _NCCL_BACKEND = None + _WORLD_SIZE = 0 + + +def get_world_size() -> int: + """Returns the world size.""" + return _WORLD_SIZE + + +def get_nccl_backend(): + return _NCCL_BACKEND diff --git a/vllm/model_executor/parallel_utils/custom_all_reduce.py b/vllm/model_executor/parallel_utils/custom_all_reduce.py new file mode 100644 index 0000000..9b5a5c3 --- /dev/null +++ b/vllm/model_executor/parallel_utils/custom_all_reduce.py @@ -0,0 +1,247 @@ +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.distributed as dist + +from vllm.logger import init_logger +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank) + +try: + from vllm._C import custom_ar + # import pynvml avoid import error +except ImportError: + # For AMD GPUs + custom_ar = None + pynvml = None + +logger = init_logger(__name__) + +_CA_HANDLE = None +_IS_CAPTURING = False +_SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + +def init_custom_ar() -> None: + global _CA_HANDLE + if _CA_HANDLE is not None: + return + rank = get_tensor_model_parallel_rank() + world_size = get_tensor_model_parallel_world_size() + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in _SUPPORTED_WORLD_SIZES: + logger.warn( + "Custom allreduce is disabled due to an unsupported world size: " + "%d. Supported world sizes: %s. To silence this warning, specify" + "disable_custom_all_reduce=True explicitly.", world_size, + str(_SUPPORTED_WORLD_SIZES)) + return + if not _can_p2p(rank, world_size): + logger.warn( + "Custom allreduce is disabled because your platform lacks GPU P2P" + " capability. 
To silence this warning, specify" + "disable_custom_all_reduce=True explicitly.") + return + _CA_HANDLE = CustomAllreduce(rank, world_size) + + +def begin_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = True + + +def end_capture() -> None: + global _IS_CAPTURING + _IS_CAPTURING = False + + +def is_capturing() -> bool: + return _IS_CAPTURING and _CA_HANDLE is not None + + +def get_handle() -> Optional["CustomAllreduce"]: + return _CA_HANDLE + + +def is_initialized() -> bool: + return _CA_HANDLE is not None + + +@contextmanager +def capture(): + try: + begin_capture() + yield + finally: + end_capture() + handle = get_handle() + if handle is not None: + handle.register_graph_buffers() + + +def custom_all_reduce(input: torch.Tensor) -> Optional[torch.Tensor]: + ca_handle = get_handle() + # when custom allreduce is disabled, this will be None + if ca_handle is None: + return + if is_capturing(): + if torch.cuda.is_current_stream_capturing(): + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_reg(input) + else: + if ca_handle.should_custom_ar(input): + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + if ca_handle.should_custom_ar(input): + return ca_handle.all_reduce_unreg(input) + + +@contextmanager +def _nvml(): + try: + pynvml.nvmlInit() + yield + finally: + pynvml.nvmlShutdown() + + +# query if the set of gpus are fully connected by nvlink (1 hop) +@_nvml() +def _is_full_nvlink(rank, world_size): + handle = pynvml.nvmlDeviceGetHandleByIndex(rank) + for i in range(world_size): + if i != rank: + try: + link_state = pynvml.nvmlDeviceGetNvLinkState(handle, i) + if not link_state: + return False + except pynvml.NVMLError as error: + logger.info( + f"NVLink detection failed with message \"{str(error)}\". " + "This is normal if your machine has no NVLink equipped") + return False + return True + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if not torch.cuda.can_device_access_peer(rank, i): + return False + return True + + +class CustomAllreduce: + + # max_size: max supported allreduce size + def __init__(self, rank, world_size, max_size=8192 * 1024) -> None: + self.max_size = max_size + self.world_size = world_size + self.full_nvlink = False + self._ptr = None + self.buffer = None + if not custom_ar.is_init(): + custom_ar.init_cumtom_ar() + # TODO aling + """ + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + self.meta = torch.zeros(custom_ar.meta_size() + max_size, + dtype=torch.uint8, + device="cuda") + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer = torch.empty(max_size, dtype=torch.uint8, device="cuda") + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. 
+ self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device="cuda") + self.max_size = max_size + self.world_size = world_size + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = _is_full_nvlink(rank, world_size) + self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data, + handles, offsets, rank, + self.full_nvlink) + self.fast_cond = self.full_nvlink or world_size <= 2 + self.register_buffer(self.buffer) + """ + #TODO align + """ + def _get_ipc_meta(self, inp: torch.Tensor): + data = inp.untyped_storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + all_data = [None] * self.world_size + dist.all_gather_object(all_data, shard_data) + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0]) + offsets.append(all_data[i][1]) + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + custom_ar.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = custom_ar.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + custom_ar.register_graph_buffers(self._ptr, handles, offsets) + """ + + def should_custom_ar(self, inp: torch.Tensor): + return custom_ar.should_custom_ar(inp, self.max_size, self.world_size, + self.full_nvlink) + + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + custom_ar.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + + def close(self): + custom_ar.dispose(self._ptr) + # TODO align + """ + if self._ptr: + custom_ar.dispose(self._ptr) + self._ptr = 0 + """ + + def __del__(self): + self.close() diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py new file mode 100644 index 0000000..c821936 --- /dev/null +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -0,0 +1,245 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""Tensor and pipeline parallel groups.""" +import contextlib + +import torch + +from vllm.model_executor.parallel_utils import cupy_utils + +# Tensor model parallel group that the current rank belongs to. +_TENSOR_MODEL_PARALLEL_GROUP = None +# Pipeline model parallel group that the current rank belongs to. +_PIPELINE_MODEL_PARALLEL_GROUP = None + +# A list of global ranks for each pipeline group to ease calculation of the +# source rank when broadcasting from the first or last pipeline stage. +_PIPELINE_GLOBAL_RANKS = None + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, +) -> None: + """ + Initialize model parallel groups. 
+ + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + num_tensor_model_parallel_groups: int = (world_size // + tensor_model_parallel_size) + num_pipeline_model_parallel_groups: int = (world_size // + pipeline_model_parallel_size) + rank = torch.distributed.get_rank() + + # Build the tensor model-parallel groups. + global _TENSOR_MODEL_PARALLEL_GROUP + assert _TENSOR_MODEL_PARALLEL_GROUP is None, ( + "tensor model parallel group is already initialized") + for i in range(num_tensor_model_parallel_groups): + ranks = range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _TENSOR_MODEL_PARALLEL_GROUP = group + + # Build the pipeline model-parallel groups. + global _PIPELINE_MODEL_PARALLEL_GROUP + global _PIPELINE_GLOBAL_RANKS + assert _PIPELINE_MODEL_PARALLEL_GROUP is None, ( + "pipeline model parallel group is already initialized") + for i in range(num_pipeline_model_parallel_groups): + ranks = range(i, world_size, num_pipeline_model_parallel_groups) + group = torch.distributed.new_group(ranks) + if rank in ranks: + _PIPELINE_MODEL_PARALLEL_GROUP = group + _PIPELINE_GLOBAL_RANKS = ranks + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + assert (get_pipeline_model_parallel_world_size( + ) == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{get_pipeline_model_parallel_world_size()=} vs. 
" + f"{pipeline_model_parallel_size=}") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TENSOR_MODEL_PARALLEL_GROUP is not None + and _PIPELINE_MODEL_PARALLEL_GROUP is not None) + + +def get_tensor_model_parallel_group(): + """Get the tensor model parallel group the caller rank belongs to.""" + assert _TENSOR_MODEL_PARALLEL_GROUP is not None, ( + "tensor model parallel group is not initialized") + return _TENSOR_MODEL_PARALLEL_GROUP + + +def get_pipeline_model_parallel_group(): + """Get the pipeline model parallel group the caller rank belongs to.""" + assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, ( + "pipeline model parallel group is not initialized") + return _PIPELINE_MODEL_PARALLEL_GROUP + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return torch.distributed.get_world_size( + group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_world_size(): + """Return world size for the pipeline model parallel group.""" + return torch.distributed.get_world_size( + group=get_pipeline_model_parallel_group()) + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return torch.distributed.get_rank(group=get_tensor_model_parallel_group()) + + +def get_pipeline_model_parallel_rank(): + """Return my rank for the pipeline model parallel group.""" + return torch.distributed.get_rank( + group=get_pipeline_model_parallel_group()) + + +def get_tensor_model_parallel_src_rank(): + """Calculate the global rank corresponding to the first local rank + in the tensor model parallel group.""" + global_rank = torch.distributed.get_rank() + local_world_size = get_tensor_model_parallel_world_size() + return (global_rank // local_world_size) * local_world_size + + +def get_pipeline_model_parallel_first_rank(): + """Return the global rank of the first process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") + return _PIPELINE_GLOBAL_RANKS[0] + + +def get_pipeline_model_parallel_last_rank(): + """Return the global rank of the last process in the pipeline for the + current tensor parallel group""" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") + last_rank_local = get_pipeline_model_parallel_world_size() - 1 + return _PIPELINE_GLOBAL_RANKS[last_rank_local] + + +def get_pipeline_model_parallel_next_rank(): + """Return the global rank that follows the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size] + + +def get_pipeline_model_parallel_prev_rank(): + """Return the global rank that precedes the caller in the pipeline""" + assert _PIPELINE_GLOBAL_RANKS is not None, ( + "Pipeline parallel group is not initialized") + rank_in_pipeline = get_pipeline_model_parallel_rank() + world_size = get_pipeline_model_parallel_world_size() + return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size] + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TENSOR_MODEL_PARALLEL_GROUP + if _TENSOR_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) 
+ _TENSOR_MODEL_PARALLEL_GROUP = None + global _PIPELINE_MODEL_PARALLEL_GROUP + if _PIPELINE_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) + _PIPELINE_MODEL_PARALLEL_GROUP = None + global _PIPELINE_GLOBAL_RANKS + _PIPELINE_GLOBAL_RANKS = None + + # Destroy the cupy states if any. + cupy_utils.destroy_process_group() + + +# Whether to use cupy for nccl all reduce. +# We use cupy for all reduce when using CUDA graph, because torch.distributed +# is not well supported by CUDA graph. +_ENABLE_CUPY_FOR_ALL_REDUCE = False + + +@contextlib.contextmanager +def with_cupy_nccl_for_all_reduce(): + """use CuPy nccl instead of torch.distributed for all reduce""" + tp_size = get_tensor_model_parallel_world_size() + if tp_size == 1: + # No-op. + # NOTE(woosuk): We don't initialize CuPy when tp_size is 1. + yield + else: + global _ENABLE_CUPY_FOR_ALL_REDUCE + old = _ENABLE_CUPY_FOR_ALL_REDUCE + _ENABLE_CUPY_FOR_ALL_REDUCE = True + + stream = torch.cuda.current_stream() + with cupy_utils.set_cupy_stream(stream): + yield + _ENABLE_CUPY_FOR_ALL_REDUCE = old + + +def is_cupy_nccl_enabled_for_all_reduce(): + """check if CuPy nccl is enabled for all reduce""" + global _ENABLE_CUPY_FOR_ALL_REDUCE + return _ENABLE_CUPY_FOR_ALL_REDUCE diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py new file mode 100644 index 0000000..0cd420c --- /dev/null +++ b/vllm/model_executor/parallel_utils/utils.py @@ -0,0 +1,48 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +from typing import Sequence + +import torch + + +def ensure_divisibility(numerator, denominator): + """Ensure that numerator is divisible by the denominator.""" + assert numerator % denominator == 0, "{} is not divisible by {}".format( + numerator, denominator) + + +def divide(numerator, denominator): + """Ensure that numerator is divisible by the denominator and return + the division value.""" + ensure_divisibility(numerator, denominator) + return numerator // denominator + + +def split_tensor_along_last_dim( + tensor: torch.Tensor, + num_partitions: int, + contiguous_split_chunks: bool = False, +) -> Sequence[torch.Tensor]: + """ Split a tensor along its last dimension. + + Arguments: + tensor: input tensor. + num_partitions: number of partitions to split the tensor + contiguous_split_chunks: If True, make each chunk contiguous + in memory. + + Returns: + A list of Tensors + """ + # Get the size and dimension. + last_dim = tensor.dim() - 1 + last_dim_size = divide(tensor.size()[last_dim], num_partitions) + # Split. + tensor_list = torch.split(tensor, last_dim_size, dim=last_dim) + # NOTE: torch.split does not create contiguous tensors by default. + if contiguous_split_chunks: + return tuple(chunk.contiguous() for chunk in tensor_list) + + return tensor_list diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py new file mode 100644 index 0000000..7deb808 --- /dev/null +++ b/vllm/model_executor/sampling_metadata.py @@ -0,0 +1,239 @@ +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch + +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SequenceData +from vllm.utils import in_wsl, is_neuron + +_SAMPLING_EPS = 1e-5 + + +class SamplingMetadata: + """Metadata for input sequences. 
Used in sampler. + + Args: + seq_groups: List of (seq_ids, sampling_params). + seq_data: Seq_id -> SequenceData. + prompt_lens: Lengths of prompts. + selected_token_indices: Token indices selected for sampling. + categorized_sample_indices: SamplingType -> token indices to sample. + generators: List of torch.Generators to use for seeded sampling + perform_sampling: Whether to perform sampling. This option is used to + make the sampling only happens in the driver worker, and disable + sampling in other worker processes. + """ + + def __init__( + self, + seq_groups: Optional[List[Tuple[List[int], SamplingParams]]], + seq_data: Optional[Dict[int, SequenceData]], + prompt_lens: Optional[List[int]], + selected_token_indices: torch.Tensor, + categorized_sample_indices: Optional[Dict[SamplingType, torch.Tensor]], + generators: Optional[List[torch.Generator]] = None, + perform_sampling: bool = True, + ) -> None: + self.seq_groups = seq_groups + self.seq_data = seq_data + self.prompt_lens = prompt_lens + self.selected_token_indices = selected_token_indices + self.categorized_sample_indices = categorized_sample_indices + self.generators = generators + self.perform_sampling = perform_sampling + + self.num_prompts = len(prompt_lens) if prompt_lens is not None else 0 + + def __repr__(self) -> str: + return ( + "SamplingMetadata(" + f"seq_groups={self.seq_groups}, " + f"seq_data={self.seq_data}, " + f"prompt_lens={self.prompt_lens}, " + f"selected_token_indices={self.selected_token_indices}, " + f"categorized_sample_indices={self.categorized_sample_indices}), " + f"perform_sampling={self.perform_sampling})") + + +@dataclass +class SamplingTensors: + """Tensors for sampling.""" + + temperatures: torch.Tensor + top_ps: torch.Tensor + top_ks: torch.Tensor + min_ps: torch.Tensor + presence_penalties: torch.Tensor + frequency_penalties: torch.Tensor + repetition_penalties: torch.Tensor + prompt_tokens: torch.Tensor + output_tokens: torch.Tensor + + @classmethod + def from_sampling_metadata( + cls, sampling_metadata: "SamplingMetadata", vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]: + prompt_tokens: List[List[int]] = [] + output_tokens: List[List[int]] = [] + top_ks: List[int] = [] + temperatures: List[float] = [] + top_ps: List[float] = [] + min_ps: List[float] = [] + presence_penalties: List[float] = [] + frequency_penalties: List[float] = [] + repetition_penalties: List[float] = [] + do_penalties = False + do_top_p_top_k = False + do_min_p = False + for i, seq_group in enumerate(sampling_metadata.seq_groups): + seq_ids, sampling_params = seq_group + temperature = sampling_params.temperature + p = sampling_params.presence_penalty + f = sampling_params.frequency_penalty + r = sampling_params.repetition_penalty + top_p = sampling_params.top_p + min_p = sampling_params.min_p + # k should not be greater than the vocab size. + top_k = min(sampling_params.top_k, vocab_size) + top_k = vocab_size if top_k == -1 else top_k + if temperature < _SAMPLING_EPS: + # NOTE: Zero temperature means deterministic sampling + # (i.e., greedy sampling or beam search). + # Set the temperature to 1 to avoid division by zero. 
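+                # Dividing the logits by 1.0 is a no-op, and greedy sampling /
+                # beam search take the argmax, so the result is unaffected.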
+ temperature = 1.0 + if not do_top_p_top_k and (top_p < 1.0 - _SAMPLING_EPS + or top_k != vocab_size): + do_top_p_top_k = True + if not do_min_p and min_p > _SAMPLING_EPS: + do_min_p = True + if not do_penalties and (abs(p) >= _SAMPLING_EPS + or abs(f) >= _SAMPLING_EPS + or abs(r - 1.0) >= _SAMPLING_EPS): + do_penalties = True + if (i < sampling_metadata.num_prompts + and sampling_params.prompt_logprobs is not None): + # For tokens in the prompt that we only need to get their logprobs + prompt_len = sampling_metadata.prompt_lens[i] + temperatures += [temperature] * (prompt_len - 1) + top_ps += [top_p] * (prompt_len - 1) + top_ks += [top_k] * (prompt_len - 1) + min_ps += [min_p] * (prompt_len - 1) + presence_penalties += [0] * (prompt_len - 1) + frequency_penalties += [0] * (prompt_len - 1) + repetition_penalties += [1] * (prompt_len - 1) + prompt_tokens.extend([] for _ in range(prompt_len - 1)) + output_tokens.extend([] for _ in range(prompt_len - 1)) + for seq_id in seq_ids: + seq_data = sampling_metadata.seq_data[seq_id] + prompt_tokens.append(seq_data.prompt_token_ids) + output_tokens.append(seq_data.output_token_ids) + temperatures += [temperature] * len(seq_ids) + top_ps += [top_p] * len(seq_ids) + top_ks += [top_k] * len(seq_ids) + min_ps += [min_p] * len(seq_ids) + presence_penalties += [p] * len(seq_ids) + frequency_penalties += [f] * len(seq_ids) + repetition_penalties += [r] * len(seq_ids) + + sampling_tensors = SamplingTensors.from_lists( + temperatures, top_ps, top_ks, min_ps, presence_penalties, + frequency_penalties, repetition_penalties, prompt_tokens, + output_tokens, vocab_size, device, dtype) + return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p) + + @classmethod + def from_lists(cls, temperatures: List[float], top_ps: List[float], + top_ks: List[int], min_ps: List[float], + presence_penalties: List[float], + frequency_penalties: List[float], + repetition_penalties: List[float], + prompt_tokens: List[List[int]], + output_tokens: List[List[int]], vocab_size: int, + device: torch.device, + dtype: torch.dtype) -> "SamplingTensors": + # Note that the performance will be very bad without + # pinned memory. 
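+        # Page-locked host memory is what makes the non-blocking
+        # host-to-device copies at the end of this method truly asynchronous;
+        # it is skipped on WSL and Neuron.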
+ pin_memory = not in_wsl() and not is_neuron() + prompt_max_len = max(len(tokens) for tokens in prompt_tokens) + prompt_padded_tokens = [ + tokens + [vocab_size] * (prompt_max_len - len(tokens)) + for tokens in prompt_tokens + ] + output_max_len = max(len(tokens) for tokens in output_tokens) + output_padded_tokens = [ + tokens + [vocab_size] * (output_max_len - len(tokens)) + for tokens in output_tokens + ] + + temperatures_t = torch.tensor( + temperatures, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ps_t = torch.tensor( + top_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + min_ps_t = torch.tensor( + min_ps, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + presence_penalties_t = torch.tensor( + presence_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + frequency_penalties_t = torch.tensor( + frequency_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + repetition_penalties_t = torch.tensor( + repetition_penalties, + device="cpu", + dtype=dtype, + pin_memory=pin_memory, + ) + top_ks_t = torch.tensor( + top_ks, + device="cpu", + dtype=torch.int, + pin_memory=pin_memory, + ) + prompt_tensor = torch.tensor( + prompt_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + output_tensor = torch.tensor( + output_padded_tokens, + device="cpu", + dtype=torch.long, + pin_memory=pin_memory, + ) + # Because the memory is pinned, we can do non-blocking + # transfer to device. + return cls( + temperatures=temperatures_t.to(device=device, non_blocking=True), + top_ps=top_ps_t.to(device=device, non_blocking=True), + top_ks=top_ks_t.to(device=device, non_blocking=True), + min_ps=min_ps_t.to(device=device, non_blocking=True), + presence_penalties=presence_penalties_t.to(device=device, + non_blocking=True), + frequency_penalties=frequency_penalties_t.to(device=device, + non_blocking=True), + repetition_penalties=repetition_penalties_t.to(device=device, + non_blocking=True), + prompt_tokens=prompt_tensor.to(device=device, non_blocking=True), + output_tokens=output_tensor.to(device=device, non_blocking=True), + ) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py new file mode 100644 index 0000000..0113e3e --- /dev/null +++ b/vllm/model_executor/utils.py @@ -0,0 +1,52 @@ +"""Utils for model executor.""" +import random +import importlib +from typing import Any, Dict, Optional + +import numpy as np +import torch + +from vllm.config import DeviceConfig, ModelConfig + +DEVICE_TO_MODEL_LOADER_MAP = { + "cuda": "model_loader", + "neuron": "neuron_model_loader", +} + + +def set_random_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. 
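+
+    Example (illustrative sketch only; `default_weight_loader` comes from
+    vllm.model_executor.weight_utils)::
+
+        w = torch.empty(8, 8)
+        set_weight_attrs(w, {"weight_loader": default_weight_loader})
+        # Setting the same attribute again would trip the assert below.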
+ """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + setattr(weight, key, value) + + +def get_model(model_config: ModelConfig, device_config: DeviceConfig, + **kwargs) -> torch.nn.Module: + model_loader_module = DEVICE_TO_MODEL_LOADER_MAP[device_config.device_type] + imported_model_loader = importlib.import_module( + f"vllm.model_executor.{model_loader_module}") + get_model_fn = imported_model_loader.get_model + return get_model_fn(model_config, device_config, **kwargs) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py new file mode 100644 index 0000000..bf90fd3 --- /dev/null +++ b/vllm/model_executor/weight_utils.py @@ -0,0 +1,300 @@ +"""Utilities for downloading and initializing model weights.""" +import filelock +import glob +import fnmatch +import json +import os +from collections import defaultdict +from typing import Any, Iterator, List, Optional, Tuple + +from huggingface_hub import snapshot_download, HfFileSystem +import numpy as np +from safetensors.torch import load_file, save_file, safe_open +import torch +from tqdm.auto import tqdm + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import (get_quantization_config, + QuantizationConfig) + +logger = init_logger(__name__) + + +class Disabledtqdm(tqdm): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs, disable=True) + + +def get_lock(model_name_or_path: str, cache_dir: Optional[str] = None): + lock_dir = cache_dir if cache_dir is not None else "/tmp" + lock_file_name = model_name_or_path.replace("/", "-") + ".lock" + lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name)) + return lock + + +def _shared_pointers(tensors): + ptrs = defaultdict(list) + for k, v in tensors.items(): + ptrs[v.data_ptr()].append(k) + failing = [] + for _, names in ptrs.items(): + if len(names) > 1: + failing.append(names) + return failing + + +def convert_bin_to_safetensor_file( + pt_filename: str, + sf_filename: str, +) -> None: + loaded = torch.load(pt_filename, map_location="cpu") + if "state_dict" in loaded: + loaded = loaded["state_dict"] + shared = _shared_pointers(loaded) + for shared_weights in shared: + for name in shared_weights[1:]: + loaded.pop(name) + + # For tensors to be contiguous + loaded = {k: v.contiguous() for k, v in loaded.items()} + + dirname = os.path.dirname(sf_filename) + os.makedirs(dirname, exist_ok=True) + save_file(loaded, sf_filename, metadata={"format": "pt"}) + + # check file size + sf_size = os.stat(sf_filename).st_size + pt_size = os.stat(pt_filename).st_size + if (sf_size - pt_size) / pt_size > 0.01: + raise RuntimeError(f"""The file size different is more than 1%: + - {sf_filename}: {sf_size} + - {pt_filename}: {pt_size} + """) + + # check if the tensors are the same + reloaded = load_file(sf_filename) + for k in loaded: + pt_tensor = loaded[k] + sf_tensor = reloaded[k] + if not torch.equal(pt_tensor, sf_tensor): + raise RuntimeError(f"The output tensors do not match for key {k}") + + +# TODO(woosuk): Move this to other place. +def get_quant_config(model_config: ModelConfig) -> QuantizationConfig: + quant_cls = get_quantization_config(model_config.quantization) + # Read the quantization config from the HF model config, if available. 
+ hf_quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if hf_quant_config is not None: + return quant_cls.from_config(hf_quant_config) + model_name_or_path = model_config.model + is_local = os.path.isdir(model_name_or_path) + if not is_local: + # Download the config files. + with get_lock(model_name_or_path, model_config.download_dir): + hf_folder = snapshot_download(model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=model_config.download_dir, + tqdm_class=Disabledtqdm) + else: + hf_folder = model_name_or_path + config_files = glob.glob(os.path.join(hf_folder, "*.json")) + + quant_config_files = [ + f for f in config_files if any( + f.endswith(x) for x in quant_cls.get_config_filenames()) + ] + if len(quant_config_files) == 0: + raise ValueError( + f"Cannot find the config file for {model_config.quantization}") + if len(quant_config_files) > 1: + raise ValueError( + f"Found multiple config files for {model_config.quantization}: " + f"{quant_config_files}") + + quant_config_file = quant_config_files[0] + with open(quant_config_file, "r") as f: + config = json.load(f) + return quant_cls.from_config(config) + + +def prepare_hf_model_weights( + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + fall_back_to_pt: bool = True, + revision: Optional[str] = None, +) -> Tuple[str, List[str], bool]: + # Download model weights from huggingface. + is_local = os.path.isdir(model_name_or_path) + use_safetensors = False + # Some quantized models use .pt files for storing the weights. + if load_format == "auto": + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == "safetensors": + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == "pt": + allow_patterns = ["*.pt"] + elif load_format == "npcache": + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break + + logger.info(f"Using model weights format {allow_patterns}") + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model_name_or_path, cache_dir): + hf_folder = snapshot_download(model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=Disabledtqdm, + revision=revision) + else: + hf_folder = model_name_or_path + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + if not use_safetensors: + # Exclude files that are not needed for inference. 
+ # https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233 + blacklist = [ + "training_args.bin", + "optimizer.bin", + "optimizer.pt", + "scheduler.pt", + "scaler.pt", + ] + hf_weights_files = [ + f for f in hf_weights_files + if not any(f.endswith(x) for x in blacklist) + ] + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + +def hf_model_weights_iterator( + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + fall_back_to_pt: Optional[bool] = True, +) -> Iterator[Tuple[str, torch.Tensor]]: + hf_folder, hf_weights_files, use_safetensors = prepare_hf_model_weights( + model_name_or_path, + cache_dir=cache_dir, + load_format=load_format, + fall_back_to_pt=fall_back_to_pt, + revision=revision) + + if load_format == "npcache": + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + + # Convert the model weights from torch tensors to numpy arrays for + # faster loading. + np_folder = os.path.join(hf_folder, "np") + os.makedirs(np_folder, exist_ok=True) + weight_names_file = os.path.join(np_folder, "weight_names.json") + # Use file lock to prevent multiple processes from + # dumping the same model weights to numpy at the same time. + with get_lock(model_name_or_path, cache_dir): + if not os.path.exists(weight_names_file): + weight_names = [] + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + param_path = os.path.join(np_folder, name) + with open(param_path, "wb") as f: + np.save(f, param.cpu().detach().numpy()) + weight_names.append(name) + with open(weight_names_file, "w") as f: + json.dump(weight_names, f) + + with open(weight_names_file, "r") as f: + weight_names = json.load(f) + + for name in weight_names: + param_path = os.path.join(np_folder, name) + with open(param_path, "rb") as f: + param = np.load(f) + yield name, torch.from_numpy(param) + elif use_safetensors: + for st_file in hf_weights_files: + with safe_open(st_file, framework="pt") as f: + for name in f.keys(): # noqa: SIM118 + param = f.get_tensor(name) + yield name, param + else: + for bin_file in hf_weights_files: + state = torch.load(bin_file, map_location="cpu") + for name, param in state.items(): + yield name, param + del state + torch.cuda.empty_cache() + + +def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: + """convert PySafeSlice object from safetensors to torch.Tensor + + PySafeSlice object supports indexing, which is done before loading the + actual tensor and can reduce the amount of memory being read into the + memory. However, it does not support more advanced functionalities + like `.view()` or `.t()`. Therefore, if we need to modify the loaded + tensor with these more complicated operators, we need to convert to + tensor first. + """ + if not isinstance(x, torch.Tensor): + x = x[:] + return x + + +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + +def initialize_dummy_weights( + model: torch.nn.Module, + low: float = -1e-3, + high: float = 1e-3, +) -> None: + """Initialize model weights with random values. + + The model weights must be randomly initialized for accurate performance + measurements. 
Additionally, the model weights should not cause NaNs in the + forward pass. We empirically found that initializing the weights with + values between -1e-3 and 1e-3 works well for most models. + """ + for param in model.state_dict().values(): + if torch.is_floating_point(param): + param.data.uniform_(low, high) \ No newline at end of file diff --git a/vllm/outputs.py b/vllm/outputs.py new file mode 100644 index 0000000..a6de2a5 --- /dev/null +++ b/vllm/outputs.py @@ -0,0 +1,141 @@ +from typing import List, Optional +import time + +from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, + SequenceStatus, RequestMetrics) +from vllm.lora.request import LoRARequest + + +class CompletionOutput: + """The output data of one completion output of a request. + + Args: + index: The index of the output in the request. + text: The generated output text. + token_ids: The token IDs of the generated output text. + cumulative_logprob: The cumulative log probability of the generated + output text. + logprobs: The log probabilities of the top probability words at each + position if the logprobs are requested. + finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. + """ + + def __init__( + self, + index: int, + text: str, + token_ids: List[int], + cumulative_logprob: float, + logprobs: Optional[SampleLogprobs], + finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.index = index + self.text = text + self.token_ids = token_ids + self.cumulative_logprob = cumulative_logprob + self.logprobs = logprobs + self.finish_reason = finish_reason + self.lora_request = lora_request + + def finished(self) -> bool: + return self.finish_reason is not None + + def __repr__(self) -> str: + return (f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason})") + + +class RequestOutput: + """The output data of a request to the LLM. + + Args: + request_id: The unique ID of the request. + prompt: The prompt string of the request. + prompt_token_ids: The token IDs of the prompt. + prompt_logprobs: The log probabilities to return per prompt token. + outputs: The output sequences of the request. + finished: Whether the whole request is finished. + metrics: Metrics associated with the request. + lora_request: The LoRA request that was used to generate the output. + """ + + def __init__( + self, + request_id: str, + prompt: str, + prompt_token_ids: List[int], + prompt_logprobs: Optional[PromptLogprobs], + outputs: List[CompletionOutput], + finished: bool, + metrics: Optional[RequestMetrics] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.request_id = request_id + self.prompt = prompt + self.prompt_token_ids = prompt_token_ids + self.prompt_logprobs = prompt_logprobs + self.outputs = outputs + self.finished = finished + self.metrics = metrics + self.lora_request = lora_request + + @classmethod + def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": + # Get the top-n sequences. 
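+        # Beam-search candidates are ranked by their length-penalized beam
+        # score, other sequences by cumulative log probability; only the best
+        # `n` sequences become CompletionOutputs.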
+ n = seq_group.sampling_params.n + seqs = seq_group.get_seqs() + if seq_group.sampling_params.use_beam_search: + sorting_key = lambda seq: seq.get_beam_search_score( + seq_group.sampling_params.length_penalty) + else: + sorting_key = lambda seq: seq.get_cumulative_logprob() + sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) + top_n_seqs = sorted_seqs[:n] + + # Create the outputs. + outputs: List[CompletionOutput] = [] + for seq in top_n_seqs: + logprobs = seq.output_logprobs + if seq_group.sampling_params.logprobs is None: + # NOTE: We need to take care of this case because the sequence + # always has the logprobs of the sampled tokens even if the + # logprobs are not requested. + logprobs = None + finshed_reason = SequenceStatus.get_finished_reason(seq.status) + output = CompletionOutput(seqs.index(seq), seq.output_text, + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), logprobs, + finshed_reason) + outputs.append(output) + + # Every sequence in the sequence group should have the same prompt. + prompt = seq_group.prompt + prompt_token_ids = seq_group.prompt_token_ids + prompt_logprobs = seq_group.prompt_logprobs + finished = seq_group.is_finished() + finished_time = time.time() if finished else None + seq_group.set_finished_time(finished_time) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + seq_group.metrics, + lora_request=seq_group.lora_request) + + def __repr__(self) -> str: + return (f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished}, " + f"metrics={self.metrics}, " + f"lora_request={self.lora_request})") diff --git a/vllm/prefix.py b/vllm/prefix.py new file mode 100644 index 0000000..5b6e8e4 --- /dev/null +++ b/vllm/prefix.py @@ -0,0 +1,87 @@ +from typing import Dict, List, Sequence, Tuple, Optional + +from vllm.block import BlockTable + + +class Prefix: + """Data and states associated with a prefix of prompt tokens for multiple + sequence groups. + + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + + Args: + token_ids: The token ids of the prefix. + block_size: The block size of the executed model. + """ + + def __init__( + self, + token_ids: Sequence[int], + block_size: int, + ) -> None: + self.token_ids = tuple(token_ids) + self.block_size = block_size + self.length = len(token_ids) + self.hash = hash(token_ids) + assert self.length % block_size == 0 + self.block_table: Optional[BlockTable] = None + self.computed = False + + @property + def allocated(self) -> bool: + return self.block_table is not None + + def get_num_blocks(self) -> int: + return self.length // self.block_size + + def get_block_numbers(self) -> List[int]: + return [block.block_number for block in self.block_table] + + def get_length(self) -> int: + return self.length + + def __hash__(self) -> int: + return self.hash + + def set_block_table(self, block_table: BlockTable) -> None: + self.block_table = block_table.copy() + + +class PrefixPool: + """Manages all the prompt prefixes. + + NOTE: This feature is experimental and may be replaced with automatic + prefix caching in the future. + + Args: + block_size: The block size of the executed model. + + Attributes: + prefixes: A list of all the prefixes. + block_size: The block size of the executed model. 
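+
+    Example (illustrative sketch; `prompt_token_ids` stands for any sequence
+    of prompt token ids)::
+
+        pool = PrefixPool(block_size=16)
+        # Returns None if fewer than `block_size` tokens are passed, because
+        # the prefix is first truncated to a multiple of the block size.
+        prefix = pool.add_or_get_prefix(prompt_token_ids, lora_int_id=0)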
+ """ + + def __init__( + self, + block_size: int, + ) -> None: + # TODO(zhuohan): Add a capacity limit to the prefix pool. + self.prefixes: Dict[int, Prefix] = {} + self.block_size = block_size + + def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: + new_length = len(token_ids) // self.block_size * self.block_size + return tuple(token_ids[:new_length]) + + def add_or_get_prefix(self, token_ids: Sequence[int], + lora_int_id: int) -> Optional[Prefix]: + token_ids = self._truncate_token_ids(token_ids) + if len(token_ids) == 0: + # Prefix is empty. + return None + prefix = Prefix(token_ids, self.block_size) + prefix_hash = hash((prefix, lora_int_id)) + if prefix_hash not in self.prefixes: + self.prefixes[prefix_hash] = prefix + return self.prefixes[prefix_hash] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py new file mode 100644 index 0000000..8103f3c --- /dev/null +++ b/vllm/sampling_params.py @@ -0,0 +1,279 @@ +"""Sampling parameters for text generation.""" +import copy +from enum import IntEnum +from functools import cached_property +from typing import Callable, List, Optional, Union + +import torch + +_SAMPLING_EPS = 1e-5 + + +class SamplingType(IntEnum): + GREEDY = 0 + RANDOM = 1 + RANDOM_SEED = 2 + BEAM = 3 + + +LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor] +"""LogitsProcessor is a function that takes a list of previously generated +tokens and a tensor of the logits for the next token, and returns a modified +tensor of logits to sample from.""" + + +class SamplingParams: + """Sampling parameters for text generation. + + Overall, we follow the sampling parameters from the OpenAI text completion + API (https://platform.openai.com/docs/api-reference/completions/create). + In addition, we support beam search, which is not supported by OpenAI. + + Args: + n: Number of output sequences to return for the given prompt. + best_of: Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`. + presence_penalty: Float that penalizes new tokens based on whether they + appear in the generated text so far. Values > 0 encourage the model + to use new tokens, while values < 0 encourage the model to repeat + tokens. + frequency_penalty: Float that penalizes new tokens based on their + frequency in the generated text so far. Values > 0 encourage the + model to use new tokens, while values < 0 encourage the model to + repeat tokens. + repetition_penalty: Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens. + temperature: Float that controls the randomness of the sampling. Lower + values make the model more deterministic, while higher values make + the model more random. Zero means greedy sampling. + top_p: Float that controls the cumulative probability of the top tokens + to consider. Must be in (0, 1]. Set to 1 to consider all tokens. + top_k: Integer that controls the number of top tokens to consider. Set + to -1 to consider all tokens. + min_p: Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this. 
+ seed: Random seed to use for the generation. + use_beam_search: Whether to use beam search instead of sampling. + length_penalty: Float that penalizes sequences based on their length. + Used in beam search. + early_stopping: Controls the stopping condition for beam search. It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm). + stop: List of strings that stop the generation when they are generated. + The returned output will not contain the stop strings. + stop_token_ids: List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens. + include_stop_str_in_output: Whether to include the stop strings in output + text. Defaults to False. + ignore_eos: Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated. + max_tokens: Maximum number of tokens to generate per output sequence. + logprobs: Number of log probabilities to return per output token. + Note that the implementation follows the OpenAI API: The return + result includes the log probabilities on the `logprobs` most likely + tokens, as well the chosen tokens. The API will always return the + log probability of the sampled token, so there may be up to + `logprobs+1` elements in the response. + prompt_logprobs: Number of log probabilities to return per prompt token. + skip_special_tokens: Whether to skip special tokens in the output. + spaces_between_special_tokens: Whether to add spaces between special + tokens in the output. Defaults to True. + logits_processors: List of functions that modify logits based on + previously generated tokens. 
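+
+    Example (illustrative)::
+
+        params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)
+        greedy = SamplingParams(temperature=0.0)  # zero temperature => greedy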
+ """ + + def __init__( + self, + n: int = 1, + best_of: Optional[int] = None, + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + seed: Optional[int] = None, + use_beam_search: bool = False, + length_penalty: float = 1.0, + early_stopping: Union[bool, str] = False, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + ignore_eos: bool = False, + max_tokens: Optional[int] = 16, + logprobs: Optional[int] = None, + prompt_logprobs: Optional[int] = None, + skip_special_tokens: bool = True, + spaces_between_special_tokens: bool = True, + logits_processors: Optional[List[LogitsProcessor]] = None, + ) -> None: + self.n = n + self.best_of = best_of if best_of is not None else n + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.repetition_penalty = repetition_penalty + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.min_p = min_p + self.seed = seed + self.use_beam_search = use_beam_search + self.length_penalty = length_penalty + self.early_stopping = early_stopping + if stop is None: + self.stop = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = list(stop) + if stop_token_ids is None: + self.stop_token_ids = [] + else: + self.stop_token_ids = list(stop_token_ids) + self.ignore_eos = ignore_eos + self.max_tokens = max_tokens + self.logprobs = logprobs + self.prompt_logprobs = prompt_logprobs + self.skip_special_tokens = skip_special_tokens + self.spaces_between_special_tokens = spaces_between_special_tokens + self.logits_processors = logits_processors + self.include_stop_str_in_output = include_stop_str_in_output + self._verify_args() + if self.use_beam_search: + self._verify_beam_search() + else: + self._verify_non_beam_search() + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. 
+ self.top_p = 1.0 + self.top_k = -1 + self.min_p = 0.0 + self._verify_greedy_sampling() + + def _verify_args(self) -> None: + if self.n < 1: + raise ValueError(f"n must be at least 1, got {self.n}.") + if self.best_of < self.n: + raise ValueError(f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}.") + if not -2.0 <= self.presence_penalty <= 2.0: + raise ValueError("presence_penalty must be in [-2, 2], got " + f"{self.presence_penalty}.") + if not -2.0 <= self.frequency_penalty <= 2.0: + raise ValueError("frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}.") + if not 0.0 < self.repetition_penalty <= 2.0: + raise ValueError("repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}.") + if self.temperature < 0.0: + raise ValueError( + f"temperature must be non-negative, got {self.temperature}.") + if not 0.0 < self.top_p <= 1.0: + raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") + if self.top_k < -1 or self.top_k == 0: + raise ValueError(f"top_k must be -1 (disable), or at least 1, " + f"got {self.top_k}.") + if not 0.0 <= self.min_p <= 1.0: + raise ValueError("min_p must be in [0, 1], got " + f"{self.min_p}.") + if self.max_tokens is not None and self.max_tokens < 1: + raise ValueError( + f"max_tokens must be at least 1, got {self.max_tokens}.") + if self.logprobs is not None and self.logprobs < 0: + raise ValueError( + f"logprobs must be non-negative, got {self.logprobs}.") + if self.prompt_logprobs is not None and self.prompt_logprobs < 0: + raise ValueError(f"prompt_logprobs must be non-negative, got " + f"{self.prompt_logprobs}.") + + def _verify_beam_search(self) -> None: + if self.best_of == 1: + raise ValueError("best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}.") + if self.temperature > _SAMPLING_EPS: + raise ValueError("temperature must be 0 when using beam search.") + if self.top_p < 1.0 - _SAMPLING_EPS: + raise ValueError("top_p must be 1 when using beam search.") + if self.top_k != -1: + raise ValueError("top_k must be -1 when using beam search.") + if self.early_stopping not in [True, False, "never"]: + raise ValueError( + f"early_stopping must be True, False, or 'never', " + f"got {self.early_stopping}.") + + def _verify_non_beam_search(self) -> None: + if self.early_stopping is not False: + raise ValueError("early_stopping is not effective and must be " + "False when not using beam search.") + if (self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS): + raise ValueError( + "length_penalty is not effective and must be the " + "default value of 1.0 when not using beam search.") + + def _verify_greedy_sampling(self) -> None: + if self.best_of > 1: + raise ValueError("best_of must be 1 when using greedy sampling." + f"Got {self.best_of}.") + + @cached_property + def sampling_type(self) -> SamplingType: + if self.use_beam_search: + return SamplingType.BEAM + if self.temperature < _SAMPLING_EPS: + return SamplingType.GREEDY + if self.seed is not None: + return SamplingType.RANDOM_SEED + return SamplingType.RANDOM + + def clone(self) -> "SamplingParams": + """Deep copy excluding LogitsProcessor objects. + + LogitsProcessor objects are excluded because they may contain an + arbitrary, nontrivial amount of data. 
+ See https://github.com/vllm-project/vllm/issues/3087 + """ + + logit_processor_refs = None if self.logits_processors is None else { + id(lp): lp + for lp in self.logits_processors + } + return copy.deepcopy(self, memo=logit_processor_refs) + + def __repr__(self) -> str: + return ( + f"SamplingParams(n={self.n}, " + f"best_of={self.best_of}, " + f"presence_penalty={self.presence_penalty}, " + f"frequency_penalty={self.frequency_penalty}, " + f"repetition_penalty={self.repetition_penalty}, " + f"temperature={self.temperature}, " + f"top_p={self.top_p}, " + f"top_k={self.top_k}, " + f"min_p={self.min_p}, " + f"seed={self.seed}, " + f"use_beam_search={self.use_beam_search}, " + f"length_penalty={self.length_penalty}, " + f"early_stopping={self.early_stopping}, " + f"stop={self.stop}, " + f"stop_token_ids={self.stop_token_ids}, " + f"include_stop_str_in_output={self.include_stop_str_in_output}, " + f"ignore_eos={self.ignore_eos}, " + f"max_tokens={self.max_tokens}, " + f"logprobs={self.logprobs}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"skip_special_tokens={self.skip_special_tokens}, " + "spaces_between_special_tokens=" + f"{self.spaces_between_special_tokens})") diff --git a/vllm/sequence.py b/vllm/sequence.py new file mode 100644 index 0000000..040e975 --- /dev/null +++ b/vllm/sequence.py @@ -0,0 +1,497 @@ +"""Sequence and its related classes.""" +import copy +import enum +from dataclasses import dataclass +from typing import Dict, List, Optional, Union + +from vllm.block import LogicalTokenBlock +from vllm.prefix import Prefix +from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest + +PromptLogprobs = List[Optional[Dict[int, float]]] +SampleLogprobs = List[Dict[int, float]] + + +class SequenceStatus(enum.Enum): + """Status of a sequence.""" + WAITING = enum.auto() + RUNNING = enum.auto() + SWAPPED = enum.auto() + FINISHED_STOPPED = enum.auto() + FINISHED_LENGTH_CAPPED = enum.auto() + FINISHED_ABORTED = enum.auto() + FINISHED_IGNORED = enum.auto() + + @staticmethod + def is_finished(status: "SequenceStatus") -> bool: + return status in [ + SequenceStatus.FINISHED_STOPPED, + SequenceStatus.FINISHED_LENGTH_CAPPED, + SequenceStatus.FINISHED_ABORTED, + SequenceStatus.FINISHED_IGNORED, + ] + + @staticmethod + def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: + if status == SequenceStatus.FINISHED_STOPPED: + finish_reason = "stop" + elif status == SequenceStatus.FINISHED_LENGTH_CAPPED: + finish_reason = "length" + elif status == SequenceStatus.FINISHED_ABORTED: + finish_reason = "abort" + elif status == SequenceStatus.FINISHED_IGNORED: + # The ignored sequences are the sequences whose prompt lengths + # are longer than the model's length cap. Therefore, the stop + # reason should also be "length" as in OpenAI API. + finish_reason = "length" + else: + finish_reason = None + return finish_reason + + +@dataclass +class RequestMetrics: + """Metrics associated with a request. + + Args: + arrival_time: The time when the request arrived. + first_scheduled_time: The time when the request was first scheduled. + first_token_time: The time when the first token was generated. + time_in_queue: The time the request spent in the queue. + finished_time: The time when the request was finished. 
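+        last_token_time: The time when the most recent token was generated.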
+ """ + arrival_time: float + last_token_time: float + first_scheduled_time: Optional[float] + first_token_time: Optional[float] + time_in_queue: Optional[float] + finished_time: Optional[float] = None + + +class SequenceData: + """Data associated with a sequence. + + Args: + prompt_token_ids: The token IDs of the prompt. + + Attributes: + prompt_token_ids: The token IDs of the prompt. + output_token_ids: The token IDs of the output. + cumulative_logprob: The cumulative log probability of the output. + """ + + def __init__( + self, + prompt_token_ids: List[int], + ) -> None: + self.prompt_token_ids = prompt_token_ids + self.output_token_ids: List[int] = [] + self.cumulative_logprob = 0.0 + + def append_token_id(self, token_id: int, logprob: float) -> None: + self.output_token_ids.append(token_id) + self.cumulative_logprob += logprob + + def get_len(self) -> int: + return len(self.output_token_ids) + len(self.prompt_token_ids) + + def get_prompt_len(self) -> int: + return len(self.prompt_token_ids) + + def get_output_len(self) -> int: + return len(self.output_token_ids) + + def get_token_ids(self) -> List[int]: + return self.prompt_token_ids + self.output_token_ids + + def get_last_token_id(self) -> int: + if not self.output_token_ids: + return self.prompt_token_ids[-1] + return self.output_token_ids[-1] + + def __repr__(self) -> str: + return (f"SequenceData(" + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob})") + + +class Sequence: + """Stores the data, status, and block information of a sequence. + + Args: + seq_id: The ID of the sequence. + prompt: The prompt of the sequence. + prompt_token_ids: The token IDs of the prompt. + block_size: The block size of the sequence. Should be the same as the + block size used by the block manager and cache engine. + lora_request: LoRA request. + """ + + def __init__( + self, + seq_id: int, + prompt: str, + prompt_token_ids: List[int], + block_size: int, + lora_request: Optional[LoRARequest] = None, + ) -> None: + self.seq_id = seq_id + self.prompt = prompt + self.block_size = block_size + self.lora_request = lora_request + + self.data = SequenceData(prompt_token_ids) + self.output_logprobs: SampleLogprobs = [] + self.output_text = "" + + self.logical_token_blocks: List[LogicalTokenBlock] = [] + # Initialize the logical token blocks with the prompt token ids. 
+ self._append_tokens_to_blocks(prompt_token_ids) + self.status = SequenceStatus.WAITING + + # Used for incremental detokenization + self.prefix_offset = 0 + self.read_offset = 0 + # Input + output tokens + self.tokens: Optional[List[str]] = None + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + def _append_logical_block(self) -> None: + block = LogicalTokenBlock( + block_number=len(self.logical_token_blocks), + block_size=self.block_size, + ) + self.logical_token_blocks.append(block) + + def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: + cursor = 0 + while cursor < len(token_ids): + if not self.logical_token_blocks: + self._append_logical_block() + + last_block = self.logical_token_blocks[-1] + if last_block.is_full(): + self._append_logical_block() + last_block = self.logical_token_blocks[-1] + + num_empty_slots = last_block.get_num_empty_slots() + last_block.append_tokens(token_ids[cursor:cursor + + num_empty_slots]) + cursor += num_empty_slots + + def append_token_id( + self, + token_id: int, + logprobs: Dict[int, float], + ) -> None: + assert token_id in logprobs + self._append_tokens_to_blocks([token_id]) + self.output_logprobs.append(logprobs) + self.data.append_token_id(token_id, logprobs[token_id]) + + def get_len(self) -> int: + return self.data.get_len() + + def get_prompt_len(self) -> int: + return self.data.get_prompt_len() + + def get_output_len(self) -> int: + return self.data.get_output_len() + + def get_token_ids(self) -> List[int]: + return self.data.get_token_ids() + + def get_last_token_id(self) -> int: + return self.data.get_last_token_id() + + def get_output_token_ids(self) -> List[int]: + return self.data.output_token_ids + + def get_cumulative_logprob(self) -> float: + return self.data.cumulative_logprob + + def get_beam_search_score(self, + length_penalty: float = 1.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + if seq_len is None: + seq_len = self.get_len() + # NOTE: HF implementation does not count the EOS token + # towards the length, we align with that here for testing. + if (eos_token_id is not None + and self.get_last_token_id() == eos_token_id): + seq_len -= 1 + return self.get_cumulative_logprob() / (seq_len**length_penalty) + + def is_finished(self) -> bool: + return SequenceStatus.is_finished(self.status) + + def fork(self, new_seq_id: int) -> "Sequence": + new_seq = copy.deepcopy(self) + new_seq.seq_id = new_seq_id + return new_seq + + def __repr__(self) -> str: + return (f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={len(self.logical_token_blocks)})") + + +@dataclass +class SequenceGroupState: + """Mutable state tied to a specific sequence group""" + + # torch.Generator used in seeded sampling + generator: Optional = None + + +class SequenceGroup: + """A group of sequences that are generated from the same prompt. + + Args: + request_id: The ID of the request. + seqs: The list of sequences. + sampling_params: The sampling parameters used to generate the outputs. + arrival_time: The arrival time of the request. + lora_request: LoRA request. + prefix: The prefix of the prompt of the sequence group. 
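+
+    A minimal, illustrative construction (hypothetical values; `seq` is an
+    existing `Sequence` and `params` a `SamplingParams`):
+
+    >>> group = SequenceGroup(request_id="req-0", seqs=[seq],
+    ...                       sampling_params=params,
+    ...                       arrival_time=time.time())
+    >>> group.num_seqs()
+    1
+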
+ """ + + def __init__( + self, + request_id: str, + seqs: List[Sequence], + sampling_params: SamplingParams, + arrival_time: float, + lora_request: Optional[LoRARequest] = None, + prefix: Optional[Prefix] = None, + ) -> None: + self.request_id = request_id + self.seqs_dict = {seq.seq_id: seq for seq in seqs} + self.sampling_params = sampling_params + self.metrics = RequestMetrics(arrival_time=arrival_time, + last_token_time=arrival_time, + first_scheduled_time=None, + first_token_time=None, + time_in_queue=None) + self.lora_request = lora_request + self.prefix: Optional[Prefix] = prefix + self.prompt_logprobs: Optional[PromptLogprobs] = None + self.state = SequenceGroupState() + + @property + def prompt(self) -> str: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).prompt + + @property + def prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + return next(iter(self.seqs_dict.values())).data.prompt_token_ids + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + def get_last_latency(self, now: float) -> float: + """Gets last token latency for Request level timings.""" + latency = now - self.metrics.last_token_time + self.metrics.last_token_time = now + return latency + + def maybe_set_first_token_time(self, time: float) -> None: + """Sets the first token time for Request level timings.""" + if self.metrics.first_token_time is None: + self.metrics.first_token_time = time + + def maybe_set_first_scheduled_time(self, time: float) -> None: + """Sets the first scheduled time and time in queue for Request level timings.""" + if self.metrics.first_scheduled_time is None: + self.metrics.first_scheduled_time = time + self.metrics.time_in_queue = time - self.metrics.arrival_time + + def set_finished_time(self, time: Optional[float]) -> None: + """Sets the finished time for Request level timings.""" + self.metrics.finished_time = time + + def get_max_num_running_seqs(self) -> int: + """The maximum number of sequences running in parallel in the remaining + lifetime of the request.""" + if self.sampling_params.use_beam_search: + # For beam search, maximally there will always be `best_of` beam + # candidates running in the future. + return self.sampling_params.best_of + else: + if self.sampling_params.best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences running. + return self.sampling_params.best_of + # At sampling stages, return the number of actual sequences + # that are not finished yet. 
+ return self.num_unfinished_seqs() + + def get_seqs( + self, + status: Optional[SequenceStatus] = None, + ) -> List[Sequence]: + if status is None: + return list(self.seqs_dict.values()) + else: + return [ + seq for seq in self.seqs_dict.values() if seq.status == status + ] + + def get_unfinished_seqs(self) -> List[Sequence]: + return [ + seq for seq in self.seqs_dict.values() if not seq.is_finished() + ] + + def get_finished_seqs(self) -> List[Sequence]: + return [seq for seq in self.seqs_dict.values() if seq.is_finished()] + + def num_seqs(self, status: Optional[SequenceStatus] = None) -> int: + return len(self.get_seqs(status)) + + def num_unfinished_seqs(self) -> int: + return len(self.get_unfinished_seqs()) + + def num_finished_seqs(self) -> int: + return len(self.get_finished_seqs()) + + def find(self, seq_id: int) -> Sequence: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + return self.seqs_dict[seq_id] + + def add(self, seq: Sequence) -> None: + if seq.seq_id in self.seqs_dict: + raise ValueError(f"Sequence {seq.seq_id} already exists.") + self.seqs_dict[seq.seq_id] = seq + + def remove(self, seq_id: int) -> None: + if seq_id not in self.seqs_dict: + raise ValueError(f"Sequence {seq_id} not found.") + del self.seqs_dict[seq_id] + + def is_finished(self) -> bool: + return all(seq.is_finished() for seq in self.get_seqs()) + + def __repr__(self) -> str: + return (f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs_dict)})") + + +class SequenceGroupMetadata: + """Metadata for a sequence group. Used to create `InputMetadata`. + + Args: + request_id: The ID of the request. + is_prompt: Whether the request is at prompt stage. + seq_data: The sequence data. (Seq id -> sequence data) + sampling_params: The sampling parameters used to generate the outputs. + block_tables: The block tables. (Seq id -> list of physical block + numbers) + state: Internal state tied to this sequence group. + lora_request: LoRA request. + prefix: The prefix of the prompt of the sequence group. + """ + + def __init__( + self, + request_id: str, + is_prompt: bool, + seq_data: Dict[int, SequenceData], + sampling_params: SamplingParams, + block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, + prefix: Optional[Prefix] = None, + state: Optional[SequenceGroupState] = None, + ) -> None: + self.request_id = request_id + self.is_prompt = is_prompt + self.seq_data = seq_data + self.sampling_params = sampling_params + self.block_tables = block_tables + self.lora_request = lora_request + self.prefix = prefix + self.state = SequenceGroupState() if state is None else state + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + + +class SequenceOutput: + """The model output associated with a sequence. + + Args: + parent_seq_id: The ID of the parent sequence (for forking in beam + search). + output_token: The output token ID. + logprobs: The logprobs of the output token. 
+ (Token id -> logP(x_i+1 | x_0, ..., x_i)) + """ + + def __init__( + self, + parent_seq_id: int, + output_token: int, + logprobs: Dict[int, float], + ) -> None: + self.parent_seq_id = parent_seq_id + self.output_token = output_token + self.logprobs = logprobs + + def __repr__(self) -> str: + return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceOutput): + raise NotImplementedError() + return (self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token + and self.logprobs == other.logprobs) + + +class SequenceGroupOutput: + """The model output associated with a sequence group.""" + + def __init__( + self, + samples: List[SequenceOutput], + prompt_logprobs: Optional[PromptLogprobs], + ) -> None: + self.samples = samples + self.prompt_logprobs = prompt_logprobs + + def __repr__(self) -> str: + return (f"SequenceGroupOutput(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})") + + def __eq__(self, other: object) -> bool: + if not isinstance(other, SequenceGroupOutput): + raise NotImplementedError() + return (self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs) + + +# For each sequence group, we generate a list of SequenceOutput object, +# each of which contains one possible candidate for the next token. +SamplerOutput = List[SequenceGroupOutput] diff --git a/vllm/test_utils.py b/vllm/test_utils.py new file mode 100644 index 0000000..75bf6ce --- /dev/null +++ b/vllm/test_utils.py @@ -0,0 +1,41 @@ +import ray + +from vllm.config import ParallelConfig +from vllm.utils import get_open_port +from vllm.worker.worker import init_distributed_environment + + +def init_test_distributed_environment( + pipeline_parallel_size: int, + tensor_parallel_size: int, + rank: int, + distributed_init_port: str, +) -> None: + parallel_config = ParallelConfig(pipeline_parallel_size, + tensor_parallel_size, + worker_use_ray=True) + distributed_init_method = f"tcp://localhost:{distributed_init_port}" + init_distributed_environment( + parallel_config, + rank, + cupy_port=None, + distributed_init_method=distributed_init_method) + + +def multi_process_tensor_parallel( + tensor_parallel_size: int, + test_target, +) -> None: + # Using ray helps debugging the error when it failed + # as compared to multiprocessing. 
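+    # `test_target` is expected to be a Ray remote function taking
+    # (tensor_parallel_size, rank, distributed_init_port); one task is
+    # launched per rank and all of them are awaited with ray.get().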
+ ray.init() + + distributed_init_port = get_open_port() + refs = [] + for rank in range(tensor_parallel_size): + refs.append( + test_target.remote(tensor_parallel_size, rank, + distributed_init_port)) + ray.get(refs) + + ray.shutdown() diff --git a/vllm/transformers_utils/__init__.py b/vllm/transformers_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py new file mode 100644 index 0000000..4a3b3c1 --- /dev/null +++ b/vllm/transformers_utils/config.py @@ -0,0 +1,52 @@ +from typing import Optional + +from transformers import AutoConfig, PretrainedConfig + +from vllm.transformers_utils.configs import * + +_CONFIG_REGISTRY = { + "chatglm": ChatGLMConfig, + "mpt": MPTConfig, + "RefinedWeb": RWConfig, # For tiiuae/falcon-40b(-instruct) + "RefinedWebModel": RWConfig, # For tiiuae/falcon-7b(-instruct) + "starcoder2": Starcoder2Config, + "minicpm": CPMDragonflyConfig, +} + + +def get_config(model: str, + trust_remote_code: bool, + revision: Optional[str] = None, + code_revision: Optional[str] = None) -> PretrainedConfig: + # FIXME(woosuk): This is a temporary fix for StarCoder2. + # Remove this when the model is supported by HuggingFace transformers. + if "bigcode" in model and "starcoder2" in model: + config_class = _CONFIG_REGISTRY["starcoder2"] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + return config + + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision) + except ValueError as e: + if (not trust_remote_code and + "requires you to execute the configuration file" in str(e)): + err_msg = ( + "Failed to load the model config. If the model is a custom " + "model not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + if config.model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[config.model_type] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + return config diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py new file mode 100644 index 0000000..4935680 --- /dev/null +++ b/vllm/transformers_utils/configs/__init__.py @@ -0,0 +1,16 @@ +from vllm.transformers_utils.configs.chatglm import ChatGLMConfig +from vllm.transformers_utils.configs.mpt import MPTConfig +# RWConfig is for the original tiiuae/falcon-40b(-instruct) and +# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the +# `FalconConfig` class from the official HuggingFace transformers library. 
+from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.starcoder2 import Starcoder2Config +from vllm.transformers_utils.configs.cpm import CPMDragonflyConfig + +__all__ = [ + "ChatGLMConfig", + "MPTConfig", + "RWConfig", + "Starcoder2Config", + "CPMDragonflyConfig", +] diff --git a/vllm/transformers_utils/configs/chatglm.py b/vllm/transformers_utils/configs/chatglm.py new file mode 100644 index 0000000..c4244f8 --- /dev/null +++ b/vllm/transformers_utils/configs/chatglm.py @@ -0,0 +1,68 @@ +# coding=utf-8 +# Adapted from +# https://github.com/THUDM/ChatGLM2-6B +from transformers import PretrainedConfig + + +class ChatGLMConfig(PretrainedConfig): + model_type = "chatglm" + attribute_map = { + "num_hidden_layers": "num_layers", + "n_head_kv": "multi_query_group_num", + } + + def __init__(self, + num_layers=28, + padded_vocab_size=65024, + hidden_size=4096, + ffn_hidden_size=13696, + kv_channels=128, + num_attention_heads=32, + seq_length=2048, + hidden_dropout=0.0, + attention_dropout=0.0, + layernorm_epsilon=1e-5, + rmsnorm=True, + apply_residual_connection_post_layernorm=False, + post_layer_norm=True, + add_bias_linear=False, + add_qkv_bias=False, + interleaved_qkv=False, + bias_dropout_fusion=True, + multi_query_attention=False, + multi_query_group_num=1, + apply_query_key_layer_scaling=True, + attention_softmax_in_fp32=True, + fp32_residual_connection=False, + quantization_bit=0, + pre_seq_len=None, + prefix_projection=False, + **kwargs): + self.num_layers = num_layers + self.vocab_size = padded_vocab_size + self.padded_vocab_size = padded_vocab_size + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.kv_channels = kv_channels + self.num_attention_heads = num_attention_heads + self.seq_length = seq_length + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + self.layernorm_epsilon = layernorm_epsilon + self.rmsnorm = rmsnorm + self.apply_residual_connection_post_layernorm = ( + apply_residual_connection_post_layernorm) + self.post_layer_norm = post_layer_norm + self.add_bias_linear = add_bias_linear + self.add_qkv_bias = add_qkv_bias + self.bias_dropout_fusion = bias_dropout_fusion + self.multi_query_attention = multi_query_attention + self.multi_query_group_num = multi_query_group_num + self.apply_query_key_layer_scaling = apply_query_key_layer_scaling + self.attention_softmax_in_fp32 = attention_softmax_in_fp32 + self.fp32_residual_connection = fp32_residual_connection + self.quantization_bit = quantization_bit + self.pre_seq_len = pre_seq_len + self.prefix_projection = prefix_projection + self.interleaved_qkv = interleaved_qkv + super().__init__(**kwargs) diff --git a/vllm/transformers_utils/configs/cpm.py b/vllm/transformers_utils/configs/cpm.py new file mode 100644 index 0000000..da4711a --- /dev/null +++ b/vllm/transformers_utils/configs/cpm.py @@ -0,0 +1,113 @@ +# coding=utf-8 +# Copyright 2022 The OpenBMB team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
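+
+# Configuration for CPM "Dragonfly" style models (e.g. MiniCPM); registered
+# under the "minicpm" model type in vllm/transformers_utils/config.py.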
+ +from typing import List, Optional, Tuple + +import torch +import torch.nn.functional as F +from transformers.configuration_utils import PretrainedConfig +from typing_extensions import TypedDict + +import math + + +class CPMDragonflyConfig(PretrainedConfig): + model_type = "cpm_dragonfly" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "scale_emb": "scale_emb", + "scale_depth": "scale_depth", + "scale": "scale", + "attention_scale": "attention_scale", + "qk_norm": "qk_norm", + "ffn_gated": "ffn_gated", + } # model specific to common + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + num_attention_heads=32, + num_key_value_heads=32, + dim_head=128, + intermediate_size=11008, + num_hidden_layers=32, + dropout_p=0.0, + hidden_act="silu", + scale=True, + scale_emb: float=1., + scale_depth: float=-1, + dim_model_base:int=None, + rms_norm_eps=1e-5, + init_std=0.02, + half: bool = True, + half_type = 'bf16', + mask_modules: Optional[List[Tuple[bool, bool]]] = None, + use_flash_attn: bool = True, + flash_attn_mask_shape="1d", + flash_impl="cuda", + base=10000, + non_checkpointing_layers_num:int = 0, + attention_scale=1, + qk_norm=False, + ffn_gated=True, + tie_lm_head=False, + max_position_embeddings=2048, + **kwargs, + ): + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + self.dim_head = dim_head + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.max_position_embeddings = max_position_embeddings + self.dropout_p = dropout_p + self.hidden_act = hidden_act + self.scale = scale + self.scale_emb = scale_emb + self.half = half + self.half_type = half_type + self.dim_model_base = dim_model_base + self.scale_depth = scale_depth + self.rms_norm_eps = rms_norm_eps + self.init_std = init_std + self.flash_impl = flash_impl + self.mask_modules = mask_modules + self.use_flash_attn = use_flash_attn + self.flash_attn_mask_shape = flash_attn_mask_shape + self.base = base + self.attention_scale=attention_scale + self.qk_norm = qk_norm + self.ffn_gated = ffn_gated + self.non_checkpointing_layers_num = non_checkpointing_layers_num + self.tie_lm_head = tie_lm_head + self.use_bfloat16 = True if self.half_type == 'bf16' else False + super().__init__(architectures=["CPMDragonflyForCausalLM"]) + + @property + def scale_width(self,): + if self.scale: + return self.hidden_size / self.dim_model_base + else: + return 1. + + @property + def scale_states(self,): + if self.scale: + return self.scale_depth / math.sqrt(self.num_hidden_layers) + else: + return 1. \ No newline at end of file diff --git a/vllm/transformers_utils/configs/falcon.py b/vllm/transformers_utils/configs/falcon.py new file mode 100644 index 0000000..c82cc60 --- /dev/null +++ b/vllm/transformers_utils/configs/falcon.py @@ -0,0 +1,87 @@ +# Adapted from +# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py +# Copyright 2023 The vLLM team. +# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Falcon configuration""" +from transformers.configuration_utils import PretrainedConfig + + +class RWConfig(PretrainedConfig): + model_type = "falcon" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = { + "num_hidden_layers": "n_layer", + "num_attention_heads": "n_head", + "num_kv_heads": "n_head_kv", + } + + def __init__( + self, + vocab_size=250880, + hidden_size=64, + n_layer=2, + n_head=8, + layer_norm_epsilon=1e-5, + initializer_range=0.02, + use_cache=True, + bos_token_id=1, + eos_token_id=2, + hidden_dropout=0.0, + attention_dropout=0.0, + multi_query=True, + n_head_kv=None, + alibi=False, + bias=False, + parallel_attn=False, + new_decoder_architecture=False, + **kwargs, + ) -> None: + self.vocab_size = vocab_size + # Backward compatibility with n_embed kwarg + n_embed = kwargs.pop("n_embed", None) + self.hidden_size = hidden_size if n_embed is None else n_embed + self.n_layer = n_layer + self.n_head = n_head + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.hidden_dropout = hidden_dropout + self.attention_dropout = attention_dropout + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.multi_query = multi_query + self.n_head_kv = 1 if n_head_kv is None else n_head_kv + self.alibi = alibi + self.bias = bias + self.parallel_attn = parallel_attn + self.new_decoder_architecture = new_decoder_architecture + + if self.hidden_size == 8192: + # Hack for falcon-40b + self.new_decoder_architecture = True + + super().__init__(bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs) + + @property + def head_dim(self): + return self.hidden_size // self.n_head + + @property + def rotary(self): + return not self.alibi diff --git a/vllm/transformers_utils/configs/mpt.py b/vllm/transformers_utils/configs/mpt.py new file mode 100644 index 0000000..5ea0d91 --- /dev/null +++ b/vllm/transformers_utils/configs/mpt.py @@ -0,0 +1,232 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/mosaicml/mpt-7b/blob/main/configuration_mpt.py +"""A HuggingFace-style model configuration.""" +import warnings +from typing import Any, Dict, Optional, Union +from transformers import PretrainedConfig + +attn_config_defaults: Dict = { + 'attn_type': 'multihead_attention', + 'attn_pdrop': 0.0, + 'attn_impl': 'triton', + 'qk_ln': False, + 'clip_qkv': None, + 'softmax_scale': None, + 'prefix_lm': False, + 'attn_uses_sequence_id': False, + 'alibi': False, + 'alibi_bias_max': 8 +} +ffn_config_defaults: Dict = {'ffn_type': 'mptmlp'} +init_config_defaults: Dict = { + 'name': 'kaiming_normal_', + 'fan_mode': 'fan_in', + 'init_nonlinearity': 'relu', + 'init_div_is_residual': True, + 'emb_init_std': None, + 'emb_init_uniform_lim': None, + 'init_std': None, + 'init_gain': 0.0 +} + + +class MPTConfig(PretrainedConfig): + model_type = 'mpt' + attribute_map = { + 'num_attention_heads': 'n_heads', + 'hidden_size': 'd_model', + 'num_hidden_layers': 'n_layers', + } + + # pylint: disable=dangerous-default-value + def __init__(self, + d_model: int = 2048, + n_heads: int = 16, + n_layers: int 
= 24, + expansion_ratio: int = 4, + max_seq_len: int = 2048, + vocab_size: int = 50368, + resid_pdrop: float = 0.0, + emb_pdrop: float = 0.0, + learned_pos_emb: bool = True, + attn_config: Dict = attn_config_defaults, + ffn_config: Dict = ffn_config_defaults, + init_device: str = 'cpu', + logit_scale: Optional[Union[float, str]] = None, + no_bias: bool = False, + embedding_fraction: float = 1.0, + norm_type: str = 'low_precision_layernorm', + use_cache: bool = False, + init_config: Dict = init_config_defaults, + fc_type: str = 'torch', + verbose: Optional[int] = None, + **kwargs: Any): + """The MPT configuration class. + Args: + d_model (int): The size of the embedding dimension of the model. + n_heads (int): The number of attention heads. + n_layers (int): The number of layers in the model. + expansion_ratio (int): The ratio of the up/down scale in the ffn. + max_seq_len (int): The maximum sequence length of the model. + vocab_size (int): The size of the vocabulary. + resid_pdrop (float): The dropout probability applied to the attention output before combining with residual. + emb_pdrop (float): The dropout probability for the embedding layer. + learned_pos_emb (bool): Whether to use learned positional embeddings + attn_config (Dict): A dictionary used to configure the model's attention module: + attn_type (str): type of attention to use. Options: multihead_attention, multiquery_attention, grouped_query_attention + attn_pdrop (float): The dropout probability for the attention layers. + attn_impl (str): The attention implementation to use. One of 'torch', 'flash', or 'triton'. + qk_ln (bool): Whether to apply layer normalization to the queries and keys in the attention layer. + clip_qkv (Optional[float]): If not None, clip the queries, keys, and values in the attention layer to + this value. + softmax_scale (Optional[float]): If not None, scale the softmax in the attention layer by this value. If None, + use the default scale of ``1/sqrt(d_keys)``. + prefix_lm (Optional[bool]): Whether the model should operate as a Prefix LM. This requires passing an + extra `prefix_mask` argument which indicates which tokens belong to the prefix. Tokens in the prefix + can attend to one another bi-directionally. Tokens outside the prefix use causal attention. + attn_uses_sequence_id (Optional[bool]): Whether to restrict attention to tokens that have the same sequence_id. + When the model is in `train` mode, this requires passing an extra `sequence_id` argument which indicates + which sub-sequence each token belongs to. + Defaults to ``False`` meaning any provided `sequence_id` will be ignored. + alibi (bool): Whether to use the alibi bias instead of position embeddings. + alibi_bias_max (int): The maximum value of the alibi bias. + kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads. + ffn_config (Dict): A dictionary used to configure the model's ffn module: + ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp + init_device (str): The device to use for parameter initialization. + logit_scale (Optional[Union[float, str]]): If not None, scale the logits by this value. + no_bias (bool): Whether to use bias in all layers. + verbose (int): The verbosity level. 0 is silent. + embedding_fraction (float): The fraction to scale the gradients of the embedding layer by. 
+ norm_type (str): choose type of norm to use + use_cache (bool): Whether or not the model should return the last key/values attentions + init_config (Dict): A dictionary used to configure the model initialization: + init_config.name: The parameter initialization scheme to use. Options: 'default_', 'baseline_', + 'kaiming_uniform_', 'kaiming_normal_', 'neox_init_', 'small_init_', 'xavier_uniform_', or + 'xavier_normal_'. These mimic the parameter initialization methods in PyTorch. + init_div_is_residual (Union[int, float, str, bool]): Value to divide initial weights by if ``module._is_residual`` is True. + emb_init_std (Optional[float]): The standard deviation of the normal distribution used to initialize the embedding layer. + emb_init_uniform_lim (Optional[Union[Tuple[float, float], float]]): The lower and upper limits of the uniform distribution + used to initialize the embedding layer. Mutually exclusive with ``emb_init_std``. + init_std (float): The standard deviation of the normal distribution used to initialize the model, + if using the baseline_ parameter initialization scheme. + init_gain (float): The gain to use for parameter initialization with kaiming or xavier initialization schemes. + fan_mode (str): The fan mode to use for parameter initialization with kaiming initialization schemes. + init_nonlinearity (str): The nonlinearity to use for parameter initialization with kaiming initialization schemes. + --- + See llmfoundry.models.utils.param_init_fns.py for info on other param init config options + fc_type (str): choose fc layer implementation. Options: torch and te. te layers support fp8 when using H100 GPUs. + """ + self.d_model = d_model + self.n_heads = n_heads + self.n_layers = n_layers + self.expansion_ratio = expansion_ratio + self.max_seq_len = max_seq_len + self.vocab_size = vocab_size + self.resid_pdrop = resid_pdrop + self.emb_pdrop = emb_pdrop + self.learned_pos_emb = learned_pos_emb + self.attn_config = attn_config + self.ffn_config = ffn_config + self.init_device = init_device + self.logit_scale = logit_scale + self.no_bias = no_bias + self.embedding_fraction = embedding_fraction + self.norm_type = norm_type + self.use_cache = use_cache + self.init_config = init_config + self.fc_type = fc_type + if verbose is not None: + warnings.warn(DeprecationWarning( + 'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.' 
+ ), + stacklevel=2) + if 'name' in kwargs: + del kwargs['name'] + if 'loss_fn' in kwargs: + del kwargs['loss_fn'] + if self.attn_config.get('alibi', False): + self.learned_pos_emb = False + warnings.warn( + f'alibi is turned on, setting `learned_pos_emb` to {self.learned_pos_emb}`', + stacklevel=2) + super().__init__(**kwargs) + self._validate_config() + + def _set_config_defaults( + self, config: Dict[str, Any], + config_defaults: Dict[str, Any]) -> Dict[str, Any]: + for (k, v) in config_defaults.items(): + if k not in config: + config[k] = v + return config + + def _validate_config(self) -> None: + self.attn_config = self._set_config_defaults(self.attn_config, + attn_config_defaults) + self.ffn_config = self._set_config_defaults(self.ffn_config, + ffn_config_defaults) + self.init_config = self._set_config_defaults(self.init_config, + init_config_defaults) + if self.d_model % self.n_heads != 0: + raise ValueError('d_model must be divisible by n_heads') + if any(( + prob < 0 or prob > 1 for prob in + [self.attn_config['attn_pdrop'], self.resid_pdrop, self.emb_pdrop] + )): + raise ValueError( + "self.attn_config['attn_pdrop'], resid_pdrop, emb_pdrop are probabilities and must be between 0 and 1" # pylint: disable=line-too-long + ) + if self.attn_config['attn_impl'] not in ['torch', 'flash', 'triton']: + raise ValueError( + f"Unknown attn_impl={self.attn_config['attn_impl']}") + if self.attn_config['prefix_lm'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'prefix_lm only implemented with torch and triton attention.') + if self.attn_config['alibi'] and self.attn_config['attn_impl'] not in [ + 'torch', 'triton' + ]: + raise NotImplementedError( + 'alibi only implemented with torch and triton attention.') + if self.attn_config['attn_uses_sequence_id'] and self.attn_config[ + 'attn_impl'] not in ['torch', 'triton']: + raise NotImplementedError( + 'attn_uses_sequence_id only implemented with torch and triton attention.' # pylint: disable=line-too-long + ) + if self.embedding_fraction > 1 or self.embedding_fraction <= 0: + raise ValueError( + 'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!' # pylint: disable=line-too-long + ) + if isinstance(self.logit_scale, + str) and self.logit_scale != 'inv_sqrt_d_model': + raise ValueError( + f"self.logit_scale={self.logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'." # pylint: disable=line-too-long + ) + if self.init_config.get('name', None) is None: + raise ValueError( + f"self.init_config={self.init_config!r} 'name' needs to be set." + ) + if not self.learned_pos_emb and (not self.attn_config['alibi']): + warnings.warn( + 'Positional information not being provided to the model.', + stacklevel=2) + if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp': + try: + # pylint: disable=import-outside-toplevel + import transformer_engine.pytorch as te + del te + except Exception as exc: + raise ImportError( + # pylint: disable=line-too-long + 'TransformerEngine import fail. `fc_type: te` requires TransformerEngine be installed. 
' + + + 'The required version of transformer_engine also requires FlashAttention v1.0.6 is installed:\n' + + 'pip install flash-attn==1.0.6 --no-build-isolation \n' + + 'pip install git+https://github.com/NVIDIA/TransformerEngine.git@144e4888b2cdd60bd52e706d5b7a79cb9c1a7156' + ) from exc + if self.ffn_config['ffn_type'] == 'mptmlp': + self.ffn_config['fc_type'] = self.fc_type + elif self.ffn_config['ffn_type'] == 'te_ln_mlp': + self.ffn_config['bias'] = not self.no_bias diff --git a/vllm/transformers_utils/configs/starcoder2.py b/vllm/transformers_utils/configs/starcoder2.py new file mode 100644 index 0000000..4c3b6b8 --- /dev/null +++ b/vllm/transformers_utils/configs/starcoder2.py @@ -0,0 +1,127 @@ +from transformers import PretrainedConfig + + +class Starcoder2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`Starcoder2Model`]. It is used to instantiate a + Starcoder2 model according to the specified arguments, defining the model architecture. Instantiating a configuration + with the defaults will yield a similar configuration to that of the [bigcode/starcoder2-7b_16k](https://huggingface.co/bigcode/starcoder2-7b_16k) model. + + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 49152): + Vocabulary size of the Starcoder2 model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Starcoder2Model`] + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 12288): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 30): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 24): + Number of attention heads for each attention layer in the Transformer encoder. + num_key_value_heads (`int`, *optional*, defaults to 2): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `8`. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. Starcoder2's sliding window attention + allows sequence of up to 4096*32 tokens. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + norm_epsilon (`float`, *optional*, defaults to 1e-05): + Epsilon value for the layer norm + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. 
+ bos_token_id (`int`, *optional*, defaults to 50256): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 50256): + The id of the "end-of-sequence" token. + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + sliding_window (`int`, *optional*): + Sliding window attention window size. If not specified, will default to `None` (no sliding window). + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + residual_dropout (`float`, *optional*, defaults to 0.0): + Residual connection dropout value. + embedding_dropout (`float`, *optional*, defaults to 0.0): + Embedding dropout. + use_bias (`bool`, *optional*, defaults to `True`): + Whether to use bias term on linear layers of the model. + + + ```python + >>> from transformers import Starcoder2Model, Starcoder2Config + + >>> # Initializing a Starcoder2 7B style configuration + >>> configuration = Starcoder2Config() + + >>> # Initializing a model from the Starcoder2 7B style configuration + >>> model = Starcoder2Model(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "starcoder2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=49152, + hidden_size=3072, + intermediate_size=12288, + num_hidden_layers=30, + num_attention_heads=24, + num_key_value_heads=2, + hidden_act="gelu_pytorch_tanh", + max_position_embeddings=4096, + initializer_range=0.018042, + norm_epsilon=1e-5, + use_cache=True, + bos_token_id=50256, + eos_token_id=50256, + rope_theta=10000.0, + sliding_window=None, + attention_dropout=0.0, + residual_dropout=0.0, + embedding_dropout=0.0, + use_bias=True, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.use_bias = use_bias + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.norm_epsilon = norm_epsilon + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + self.residual_dropout = residual_dropout + self.embedding_dropout = embedding_dropout + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + **kwargs, + ) + if self.architectures is None: + self.architectures = ['Starcoder2ForCausalLM'] diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py new file mode 100644 index 0000000..6edc225 --- /dev/null +++ b/vllm/transformers_utils/tokenizer.py @@ -0,0 +1,245 @@ +from typing import List, Optional, Tuple, Union + +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache +from vllm.transformers_utils.tokenizers import * + +logger = init_logger(__name__) + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + tokenizer_revision: Optional[str] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + if tokenizer_mode == 
"slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if (not trust_remote_code and + ("does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e))): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer is a custom " + "tokenizer not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + except AttributeError as e: + if "BaichuanTokenizer" in str(e): + # This is for the error "'BaichuanTokenizer' object has no + # attribute 'sp_model'". + tokenizer = BaichuanTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + tokenizer_revision=tokenizer_revision, + **kwargs) + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + return tokenizer + + +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. 
" + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class TokenizerGroup: + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. 
+ sub_texts = [] + current_sub_text = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int = 0, + read_offset: int = 0, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + if prev_tokens is None: + new_tokens = tokenizer.convert_ids_to_tokens( + all_input_ids, skip_special_tokens=skip_special_tokens) + output_tokens = new_tokens + # 5 is an arbitrary value that should work for all + # tokenizers (bigger = more conservative). + # Subtract 1 extra to account for the generated token. + prefix_offset = max(len(output_tokens) - 6, 0) + # If the first new token is a special token, we can't skip 1 extra token + if skip_special_tokens and new_token_id in tokenizer.all_special_ids: + read_offset = max(len(output_tokens), 0) + else: + read_offset = max(len(output_tokens) - 1, 0) + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + output_tokens = prev_tokens + new_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. + if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) > len(prefix_text) and not new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. 
+ # If it's in the middle, it's probably a real invalid id generated + # by the model + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) + else: + return new_tokens, "", prefix_offset, read_offset diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py new file mode 100644 index 0000000..e6b5972 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -0,0 +1,5 @@ +from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer + +__all__ = [ + "BaichuanTokenizer", +] diff --git a/vllm/transformers_utils/tokenizers/baichuan.py b/vllm/transformers_utils/tokenizers/baichuan.py new file mode 100644 index 0000000..1dd241e --- /dev/null +++ b/vllm/transformers_utils/tokenizers/baichuan.py @@ -0,0 +1,263 @@ +# yapf: disable +# Adapted from +# https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/blob/8f6e343d545c503b91429582231d1d354dac2740/tokenization_baichuan.py +# This includes a fix suggested in +# https://github.com/vllm-project/vllm/issues/1403#issuecomment-1767503058 +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +import os +from shutil import copyfile +from typing import Any, Dict, List, Optional, Tuple + +import sentencepiece as spm +from transformers.tokenization_utils import AddedToken, PreTrainedTokenizer +from transformers.utils import logging + + +logger = logging.get_logger(__name__) + +VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"} + +PRETRAINED_VOCAB_FILES_MAP = { + "vocab_file": {}, + "tokenizer_file": {}, +} +PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {} + + +class BaichuanTokenizer(PreTrainedTokenizer): + """ + Construct a Baichuan tokenizer. Based on byte-level Byte-Pair-Encoding. + + Args: + vocab_file (`str`): + Path to the vocabulary file. 
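+
+    Illustrative usage (assuming a local SentencePiece model file named
+    "tokenizer.model"):
+
+    >>> tokenizer = BaichuanTokenizer(vocab_file="tokenizer.model")
+    >>> ids = tokenizer.encode("hello")
+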
+ """ + + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP + max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES + model_input_names = ["input_ids", "attention_mask"] + + def __init__( + self, + vocab_file, + unk_token="", + bos_token="", + eos_token="", + pad_token=None, + sp_model_kwargs: Optional[Dict[str, Any]] = None, + add_bos_token=True, + add_eos_token=False, + clean_up_tokenization_spaces=False, + **kwargs, + ): + self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs + bos_token = ( + AddedToken(bos_token, lstrip=False, rstrip=False) + if isinstance(bos_token, str) + else bos_token + ) + eos_token = ( + AddedToken(eos_token, lstrip=False, rstrip=False) + if isinstance(eos_token, str) + else eos_token + ) + unk_token = ( + AddedToken(unk_token, lstrip=False, rstrip=False) + if isinstance(unk_token, str) + else unk_token + ) + pad_token = ( + AddedToken(pad_token, lstrip=False, rstrip=False) + if isinstance(pad_token, str) + else pad_token + ) + self.vocab_file = vocab_file + self.add_bos_token = add_bos_token + self.add_eos_token = add_eos_token + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(vocab_file) + super().__init__( + bos_token=bos_token, + eos_token=eos_token, + unk_token=unk_token, + pad_token=pad_token, + add_bos_token=add_bos_token, + add_eos_token=add_eos_token, + sp_model_kwargs=self.sp_model_kwargs, + clean_up_tokenization_spaces=clean_up_tokenization_spaces, + **kwargs, + ) + + def __getstate__(self): + state = self.__dict__.copy() + state["sp_model"] = None + return state + + def __setstate__(self, d): + self.__dict__ = d + self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs) + self.sp_model.Load(self.vocab_file) + + @property + def vocab_size(self): + """Returns vocab size""" + return self.sp_model.get_piece_size() + + def get_vocab(self): + """Returns vocab as a dict""" + vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)} + vocab.update(self.added_tokens_encoder) + return vocab + + def _tokenize(self, text): + """Returns a tokenized string.""" + return self.sp_model.encode(text, out_type=str) + + def _convert_token_to_id(self, token): + """Converts a token (str) in an id using the vocab.""" + return self.sp_model.piece_to_id(token) + + def _convert_id_to_token(self, index): + """Converts an index (integer) in a token (str) using the vocab.""" + token = self.sp_model.IdToPiece(index) + return token + + def convert_tokens_to_string(self, tokens): + """Converts a sequence of tokens (string) in a single string.""" + current_sub_tokens = [] + out_string = "" + prev_is_special = False + for i, token in enumerate(tokens): + # make sure that special tokens are not decoded using sentencepiece model + if token in self.all_special_tokens: + if not prev_is_special and i != 0: + out_string += " " + out_string += self.sp_model.decode(current_sub_tokens) + token + prev_is_special = True + current_sub_tokens = [] + else: + current_sub_tokens.append(token) + prev_is_special = False + out_string += self.sp_model.decode(current_sub_tokens) + return out_string + + def save_vocabulary( + self, save_directory, filename_prefix: Optional[str] = None + ) -> Tuple[str]: + """ + Save the vocabulary and special tokens file to a directory. + + Args: + save_directory (`str`): + The directory in which to save the vocabulary. + + Returns: + `Tuple(str)`: Paths to the files saved. 
+ """ + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + out_vocab_file = os.path.join( + save_directory, + (filename_prefix + "-" if filename_prefix else "") + + VOCAB_FILES_NAMES["vocab_file"], + ) + + if os.path.abspath(self.vocab_file) != os.path.abspath( + out_vocab_file + ) and os.path.isfile(self.vocab_file): + copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) + + return (out_vocab_file,) + + def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None): + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = bos_token_id + token_ids_0 + eos_token_id + + if token_ids_1 is not None: + output = output + bos_token_id + token_ids_1 + eos_token_id + + return output + + def get_special_tokens_mask( + self, + token_ids_0: List[int], + token_ids_1: Optional[List[int]] = None, + already_has_special_tokens: bool = False, + ) -> List[int]: + """ + Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding + special tokens using the tokenizer `prepare_for_model` method. + + Args: + token_ids_0 (`List[int]`): + List of IDs. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + already_has_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not the token list is already formatted with special tokens for the model. + + Returns: + `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. + """ + if already_has_special_tokens: + return super().get_special_tokens_mask( + token_ids_0=token_ids_0, + token_ids_1=token_ids_1, + already_has_special_tokens=True, + ) + + bos_token_id = [1] if self.add_bos_token else [] + eos_token_id = [1] if self.add_eos_token else [] + + if token_ids_1 is None: + return bos_token_id + ([0] * len(token_ids_0)) + eos_token_id + return ( + bos_token_id + + ([0] * len(token_ids_0)) + + eos_token_id + + bos_token_id + + ([0] * len(token_ids_1)) + + eos_token_id + ) + + def create_token_type_ids_from_sequences( + self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None + ) -> List[int]: + """ + Creates a mask from the two sequences passed to be used in a sequence-pair classification task. An ALBERT + sequence pair mask has the following format: + + ``` + 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 + | first sequence | second sequence | + ``` + + if token_ids_1 is None, only returns the first portion of the mask (0s). + + Args: + token_ids_0 (`List[int]`): + List of ids. + token_ids_1 (`List[int]`, *optional*): + Optional second list of IDs for sequence pairs. + + Returns: + `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s). 
+ """ + bos_token_id = [self.bos_token_id] if self.add_bos_token else [] + eos_token_id = [self.eos_token_id] if self.add_eos_token else [] + + output = [0] * len(bos_token_id + token_ids_0 + eos_token_id) + + if token_ids_1 is not None: + output += [1] * len(bos_token_id + token_ids_1 + eos_token_id) + + return output diff --git a/vllm/utils.py b/vllm/utils.py new file mode 100644 index 0000000..a4f9bfe --- /dev/null +++ b/vllm/utils.py @@ -0,0 +1,311 @@ +import enum +import os +import socket +import subprocess +import uuid +from platform import uname +from typing import List, Tuple, Union +from packaging.version import parse, Version + +import psutil +import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional + +from vllm.logger import init_logger + +T = TypeVar("T") +logger = init_logger(__name__) + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8_e5m2": torch.uint8, +} + + +class Device(enum.Enum): + GPU = enum.auto() + CPU = enum.auto() + + +class Counter: + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: Any) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + +def is_hip() -> bool: + return torch.version.hip is not None + + +def is_neuron() -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + # NOTE: This import statement should be executed lazily since + # the Neuron-X backend does not have the `cuda_utils` module. 
+ from vllm._C import cuda_utils + + max_shared_mem = cuda_utils.get_max_shared_memory_per_block_device_attribute( + gpu) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper + + +def get_ip() -> str: + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("dns.google", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except OSError: + # try ipv6 + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + s.connect(("dns.google", 80)) + return s.getsockname()[0] + + +def get_distributed_init_method(ip: str, port: int) -> str: + return f"tcp://{ip}:{port}" + + +def get_open_port() -> int: + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def set_cuda_visible_devices(device_ids: List[int]) -> None: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) + + +def get_nvcc_cuda_version() -> Optional[Version]: + cuda_home = os.environ.get('CUDA_HOME') + if not cuda_home: + cuda_home = '/usr/local/cuda' + if os.path.isfile(cuda_home + '/bin/nvcc'): + logger.info( + f'CUDA_HOME is not found in the environment. Using {cuda_home} as CUDA_HOME.' + ) + else: + logger.warning( + f'Not found nvcc in {cuda_home}. Skip cuda version check!') + return None + nvcc_output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], + universal_newlines=True) + output = nvcc_output.split() + release_idx = output.index("release") + 1 + nvcc_cuda_version = parse(output[release_idx].split(",")[0]) + return nvcc_cuda_version + + +def _generate_random_fp8_e5m2( + tensor: torch.tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. 
+ # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm._C import cache_ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + cache_ops.convert_fp8_e5m2(tensor_tmp, tensor) + del tensor_tmp + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: Optional[int] = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8_e5m2": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype == 'fp8_e5m2': + _generate_random_fp8_e5m2(key_cache, -scale, scale) + elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + key_cache.uniform_(-scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype == 'fp8_e5m2': + _generate_random_fp8_e5m2(value_cache, -scale, scale) + elif torch_dtype in [torch.half, torch.bfloat16, torch.float]: + value_cache.uniform_(-scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches diff --git a/vllm/worker/__init__.py b/vllm/worker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py new file mode 100644 index 0000000..dbeb730 --- /dev/null +++ b/vllm/worker/cache_engine.py @@ -0,0 +1,215 @@ +import os +enable_infer_paged_attn = os.getenv("ENABLE_INFER_PAGED_ATTN",None) +"""CacheEngine class for managing the KV cache.""" +from typing import Dict, List, Tuple + +import torch + +from vllm.config import CacheConfig, ModelConfig, ParallelConfig +from vllm.logger import init_logger +from vllm.utils import in_wsl, is_neuron, STR_DTYPE_TO_TORCH_DTYPE + +logger = init_logger(__name__) + +KVCache = Tuple[torch.Tensor, torch.Tensor] + + +class CacheEngine: + """Manages the KV cache. + + This class is responsible for initializing and managing the GPU and CPU KV + caches. 
It also provides methods for performing KV cache operations, such + as swapping and copying. + """ + + def __init__( + self, + cache_config: CacheConfig, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> None: + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + self.num_gpu_blocks = cache_config.num_gpu_blocks + self.num_cpu_blocks = cache_config.num_cpu_blocks + + # Skip initializing CUDA stream and buffer for Neuron backend. + if is_neuron(): + return + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Initialize the cache. + self.gpu_cache = self.allocate_gpu_cache() + self.cpu_cache = self.allocate_cpu_cache() + + # Initialize the stream for caching operations. + self.cache_stream = torch.cuda.Stream() + assert self.cache_stream != torch.cuda.current_stream() + # Initialize the events for stream synchronization. + self.events = [torch.cuda.Event() for _ in range(self.num_layers)] + + def get_key_block_shape(self) -> Tuple[int, int, int, int]: + element_size = torch.tensor([], dtype=self.dtype).element_size() + x = 16 // element_size + use_v2 = self.head_size == 128 and self.block_size == 16 and enable_infer_paged_attn is None + if use_v2: + return ( + self.num_heads, + self.block_size, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) + + def get_value_block_shape(self) -> Tuple[int, int, int]: + use_v2 = self.head_size == 128 and self.block_size == 16 and enable_infer_paged_attn is None + if use_v2: + return ( + self.num_heads, + self.block_size, + self.head_size, + ) + else: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) + + # TODO align + """ + def get_key_block_shape(self) -> Tuple[int, int, int, int]: + element_size = torch.tensor([], dtype=self.dtype).element_size() + x = 16 // element_size + return ( + self.num_heads, + self.head_size // x, + self.block_size, + x, + ) + + def get_value_block_shape(self) -> Tuple[int, int, int]: + return ( + self.num_heads, + self.head_size, + self.block_size, + ) + """ + def allocate_gpu_cache(self) -> List[KVCache]: + gpu_cache: List[KVCache] = [] + key_block_shape = self.get_key_block_shape() + value_block_shape = self.get_value_block_shape() + for _ in range(self.num_layers): + key_blocks = torch.zeros( + size=(self.num_gpu_blocks, *key_block_shape), + dtype=self.dtype, + device="cuda", + ) + value_blocks = torch.zeros( + size=(self.num_gpu_blocks, *value_block_shape), + dtype=self.dtype, + device="cuda", + ) + gpu_cache.append((key_blocks, value_blocks)) + return gpu_cache + + def allocate_cpu_cache(self) -> List[KVCache]: + cpu_cache: List[KVCache] = [] + key_block_shape = self.get_key_block_shape() + value_block_shape = self.get_value_block_shape() + pin_memory = not in_wsl() + if not pin_memory: + # Pinning memory in WSL is not supported. + # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications + logger.warning("Using 'pin_memory=False' as WSL is detected. 
" + "This may slow down the performance.") + for _ in range(self.num_layers): + key_blocks = torch.zeros( + size=(self.num_cpu_blocks, *key_block_shape), + dtype=self.dtype, + pin_memory=pin_memory, + device="cpu", + ) + value_blocks = torch.zeros( + size=(self.num_cpu_blocks, *value_block_shape), + dtype=self.dtype, + pin_memory=pin_memory, + device="cpu", + ) + cpu_cache.append((key_blocks, value_blocks)) + return cpu_cache + + def _swap( + self, + src: List[KVCache], + dst: List[KVCache], + src_to_dst: Dict[int, int], + ) -> None: + from vllm._C import cache_ops + + with torch.cuda.stream(self.cache_stream): + for i in range(self.num_layers): + src_key_cache, src_value_cache = src[i] + dst_key_cache, dst_value_cache = dst[i] + # Copy the key blocks. + cache_ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + # Copy the value blocks. + cache_ops.swap_blocks(src_value_cache, dst_value_cache, + src_to_dst) + event = self.events[i] + event.record(stream=self.cache_stream) + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + self._swap(self.cpu_cache, self.gpu_cache, src_to_dst) + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + self._swap(self.gpu_cache, self.cpu_cache, src_to_dst) + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + from vllm._C import cache_ops + + key_caches = [key_cache for key_cache, _ in self.gpu_cache] + value_caches = [value_cache for _, value_cache in self.gpu_cache] + # NOTE(woosuk): This operation implicitly synchronizes the CPU and GPU. + cache_ops.copy_blocks(key_caches, value_caches, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = _get_dtype_size(dtype) + return dtype_size * total + + +def _get_dtype_size(dtype: torch.dtype) -> int: + return torch.tensor([], dtype=dtype).element_size() diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py new file mode 100644 index 0000000..3885993 --- /dev/null +++ b/vllm/worker/model_runner.py @@ -0,0 +1,1223 @@ +import contextlib +import time +from typing import Dict, List, Optional, Tuple, Set, Union + +import numpy as np +import torch +import torch.nn as nn + +from vllm.config import (DeviceConfig, ModelConfig, LoRAConfig, ParallelConfig, + SchedulerConfig) +from vllm.logger import init_logger +from vllm.model_executor import get_model, InputMetadata, SamplingMetadata +from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + with_cupy_nccl_for_all_reduce) +from vllm.model_executor.parallel_utils import custom_all_reduce +from vllm.sampling_params import SamplingParams, SamplingType +from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.utils import in_wsl + +logger = 
init_logger(__name__) + +KVCache = Tuple[torch.Tensor, torch.Tensor] +_PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 +# Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# NOTE: _get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ): + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.lora_config = lora_config + self.is_driver_worker = False + # TODO align + """ + self.is_driver_worker = is_driver_worker + """ + + # model_config can be None in tests/samplers/test_sampler.py. + # FIXME(woosuk): This is a hack to make the tests work. Refactor this. + self.sliding_window = (model_config.get_sliding_window() + if model_config is not None else None) + self.device_config = (device_config + if device_config is not None else DeviceConfig()) + self.device = self.device_config.device + + self.model = None + self.block_size = None # Set after initial profiling. + self.lora_manager = None + + self.graph_runners: Dict[int, CUDAGraphRunner] = {} + self.graph_memory_pool = None # Set during graph capture. + + self.max_context_len_to_capture = ( + self.model_config.max_context_len_to_capture + if self.model_config is not None else 0) + # When using CUDA graph, the input block tables must be padded to + # max_context_len_to_capture. However, creating the block table in + # Python can be expensive. To optimize this, we cache the block table + # in numpy and only copy the actual input content at every iteration. + # The shape of the cached block table will be + # (max batch size to capture, max context len to capture / block size). + self.graph_block_tables = None # Set after initial profiling. 
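+        # Illustrative sizing (the concrete numbers are assumptions, not taken
+        # from the config): with max_context_len_to_capture = 8192 and
+        # block_size = 16 the cached table has 8192 // 16 = 512 columns, and
+        # with max(_BATCH_SIZES_TO_CAPTURE) = 256 rows it is a 256 x 512 int32
+        # numpy array (~512 KiB) that is reused across decode steps instead of
+        # being rebuilt in Python every iteration.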
+ # cache in_wsl result + self.in_wsl = in_wsl() + self.kv_cache_dtype = kv_cache_dtype + + # Set enforce_eager to True for Neuron backend, to avoid capturing graph + if self.device_config.is_neuron: + self.model_config.enforce_eager = True + + def load_model(self) -> None: + self.model = get_model(self.model_config, + self.device_config, + lora_config=self.lora_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + assert hasattr( + self.model, "supported_lora_modules" + ) and self.model.supported_lora_modules, "Model does not support LoRA" + assert hasattr( + self.model, + "embedding_modules"), "Model does not have embedding_modules" + assert hasattr(self.model, "embedding_padding_modules" + ), "Model does not have embedding_padding_modules" + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens + + self.scheduler_config.max_paddings, vocab_size, + self.lora_config, self.device, self.model.embedding_modules, + self.model.embedding_padding_modules) + self.model = self.lora_manager.create_lora_manager(self.model) + + def set_block_size(self, block_size: int) -> None: + self.block_size = block_size + + max_num_blocks = (self.max_context_len_to_capture + block_size - + 1) // block_size + self.graph_block_tables = np.zeros( + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + List[int], List[int], Set[LoRARequest]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + prompt_lens: List[int] = [] + context_lens: List[int] = [] + subquery_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + prefix_len = 0 + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] + prefix_block_tables.append(prefix.get_block_numbers()) + else: + prefix_block_tables.append([]) + # actual prompt lens + context_lens.append(prefix_len) + subquery_lens.append(prompt_len - prefix_len) + + input_tokens.extend(prompt_tokens) + input_positions.extend(list(range(prefix_len, prefix_len + len(prompt_tokens)))) + + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping.append([lora_id] * (prompt_len - prefix_len)) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len - prefix_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + continue + + # Compute the slot mapping. 
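+            # Worked example of the slot computation below (illustrative
+            # values, not from the source): with block_size = 4 and
+            # block_table = [7, 2], token position i = 5 lives in
+            # block_number = block_table[5 // 4] = 2 at block_offset = 5 % 4 = 1,
+            # so its slot is 2 * 4 + 1 = 9, i.e. the second entry of physical
+            # block 2 in the flat KV cache.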
+ slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + assert prefix_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, prompt_len - self.sliding_window) + for i in range(prompt_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_prompt_len = max(subquery_lens) + slot_mappings = [] + lora_index_mappints = [] + for mapping in slot_mapping: + slot_mappings.extend(mapping) + for mapping in lora_index_mapping: + lora_index_mappints.extend(mapping) + + input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) + input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) + slot_mapping = torch.tensor(slot_mappings, dtype=torch.int, device=self.device) + + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) + block_tables = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.int, + device=self.device) + start_loc_tensor = prompt_lens_tensor.cumsum(dim=-1) + + input_metadata = InputMetadata( + prompt_lens=prompt_lens, + is_prompt=True, + slot_mapping=slot_mapping, + max_seq_len=max_prompt_len, + start_loc=start_loc_tensor, + max_context_len=None, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) + + # TODO align + """ + def _prepare_prompt( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + List[int], List[int], Set[LoRARequest]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + prompt_lens: List[int] = [] + context_lens: List[int] = [] + subquery_lens: List[int] = [] + prefix_block_tables: List[List[int]] = [] + for seq_group_metadata in seq_group_metadata_list: + assert seq_group_metadata.is_prompt + seq_ids = list(seq_group_metadata.seq_data.keys()) + assert len(seq_ids) == 1 + seq_id = seq_ids[0] + + seq_data = seq_group_metadata.seq_data[seq_id] + prompt_tokens = seq_data.get_token_ids() + prompt_len = len(prompt_tokens) + prompt_lens.append(prompt_len) + prefix_len = 0 + prefix = seq_group_metadata.prefix + if prefix is not None and prefix.computed: + prefix_len = prefix.get_length() + prompt_tokens = prompt_tokens[prefix_len:] + 
prefix_block_tables.append(prefix.get_block_numbers()) + else: + prefix_block_tables.append([]) + # actual prompt lens + context_lens.append(prefix_len) + subquery_lens.append(prompt_len - prefix_len) + + input_tokens.append(prompt_tokens) + # NOTE(woosuk): Here we assume that the first token in the prompt + # is always the first token in the sequence. + input_positions.append( + list(range(prefix_len, prefix_len + len(prompt_tokens)))) + + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping.append([lora_id] * (prompt_len - prefix_len)) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len - prefix_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + + if seq_group_metadata.block_tables is None: + # During memory profiling, the block tables are not initialized + # yet. In this case, we just use a dummy slot mapping. + slot_mapping.append([_PAD_SLOT_ID] * prompt_len) + continue + + # Compute the slot mapping. + slot_mapping.append([]) + block_table = seq_group_metadata.block_tables[seq_id] + # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, + # where start_idx is max(0, prompt_len - sliding_window). + # For example, if the prompt len is 10, sliding window is 8, and + # block size is 4, the first two tokens are masked and the slot + # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. + start_idx = 0 + if self.sliding_window is not None: + assert prefix_len == 0, ( + "Prefix caching is currently not supported with " + "sliding window attention") + start_idx = max(0, prompt_len - self.sliding_window) + for i in range(prefix_len, prompt_len): + if i < start_idx: + slot_mapping[-1].append(_PAD_SLOT_ID) + continue + + block_number = block_table[i // self.block_size] + block_offset = i % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping[-1].append(slot) + + max_prompt_len = max(subquery_lens) + input_tokens = _make_tensor_with_pad(input_tokens, + max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad(input_positions, + max_prompt_len, + pad=0, + dtype=torch.long, + device=self.device) + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_prompt_len, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) + lora_index_mapping = [ + _pad_to_max(mapping, max_prompt_len, pad=0) + for mapping in lora_index_mapping + ] + context_lens_tensor = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + # Prepare prefix block tables + max_prompt_block_table_len = max(len(t) for t in prefix_block_tables) + block_tables = _make_tensor_with_pad( + prefix_block_tables, + max_len=max_prompt_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + start_loc_tensor = torch.arange(0, + len(prompt_lens) * max_prompt_len, + max_prompt_len, + dtype=torch.long, + device=self.device) + prompt_lens_tensor = torch.tensor(prompt_lens, + dtype=torch.long, + device=self.device) + + input_metadata = InputMetadata( + is_prompt=True, + slot_mapping=slot_mapping, + prompt_lens=prompt_lens_tensor, + max_seq_len=max_prompt_len, + start_loc=start_loc_tensor, + max_context_len=None, + context_lens=context_lens_tensor, + block_tables=block_tables, + use_cuda_graph=False, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) + """ + + def _prepare_decode( + 
self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[int] = [] + input_positions: List[int] = [] + slot_mapping: List[int] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.extend([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.extend([position]) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.extend([slot]) + lora_index_mapping.append([lora_id]) + lora_prompt_mapping.append(lora_id) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + batch_size = len(input_tokens) + max_context_len = max(context_lens) + + use_captured_graph = ( + not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_context_len <= self.max_context_len_to_capture) + + if use_captured_graph: + # Pad the input tokens, positions, and slot mapping to match the + # batch size of the captured graph. + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + input_tokens.extend([0] * (graph_batch_size - batch_size)) + input_positions.extend([0] * (graph_batch_size - batch_size)) + slot_mapping.extend([_PAD_SLOT_ID] * (graph_batch_size - batch_size)) + context_lens.extend([1] * (graph_batch_size - batch_size)) + for _ in range(graph_batch_size - batch_size): + block_tables.append([]) + batch_size = graph_batch_size + + # When using CUDA graph, we don't need to make the tensors on the GPU + # because they will be eventually copied to the designated GPU buffer. + input_tokens = torch.tensor(input_tokens, dtype=torch.long, device=self.device) + input_positions = torch.tensor(input_positions, dtype=torch.long, device=self.device) + slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device=self.device) + context_lens = torch.tensor(context_lens,dtype=torch.int,device=self.device) + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. 
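+            # Copying into the preallocated numpy table (instead of building a
+            # fresh padded list) keeps per-step Python overhead low. Rows past
+            # each sequence's block count keep whatever values they already
+            # held; this is assumed to be harmless because context_lens bounds
+            # how many blocks the attention kernel actually reads.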
+ input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=self.device) + else: + max_block_table_len = max([len(t) for t in block_tables]) + block_tables = _make_tensor_with_pad(block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device) + + lora_index_mapping = [ + _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping + ] + + input_metadata = InputMetadata( + prompt_lens=[], + is_prompt=False, + slot_mapping=slot_mapping, + max_seq_len=None, + start_loc=None, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, + ) + return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests + + # TODO align + """ + def _prepare_decode( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + Set[LoRARequest]]: + assert len(seq_group_metadata_list) > 0 + input_tokens: List[List[int]] = [] + input_positions: List[List[int]] = [] + slot_mapping: List[List[int]] = [] + context_lens: List[int] = [] + block_tables: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() + + for seq_group_metadata in seq_group_metadata_list: + assert not seq_group_metadata.is_prompt + + seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + for seq_id in seq_ids: + seq_data = seq_group_metadata.seq_data[seq_id] + generation_token = seq_data.get_last_token_id() + input_tokens.append([generation_token]) + + seq_len = seq_data.get_len() + position = seq_len - 1 + input_positions.append([position]) + + context_len = seq_len if self.sliding_window is None else min( + seq_len, self.sliding_window) + context_lens.append(context_len) + + block_table = seq_group_metadata.block_tables[seq_id] + block_number = block_table[position // self.block_size] + block_offset = position % self.block_size + slot = block_number * self.block_size + block_offset + slot_mapping.append([slot]) + lora_index_mapping.append([lora_id]) + lora_prompt_mapping.append(lora_id) + + if self.sliding_window is not None: + sliding_window_blocks = (self.sliding_window // + self.block_size) + block_table = block_table[-sliding_window_blocks:] + block_tables.append(block_table) + + batch_size = len(input_tokens) + max_context_len = max(context_lens) + use_captured_graph = ( + not self.model_config.enforce_eager + and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and max_context_len <= self.max_context_len_to_capture) + if use_captured_graph: + # Pad the input tokens, positions, and slot mapping to match the + # batch size of the captured graph. 
+ graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + for _ in range(graph_batch_size - batch_size): + input_tokens.append([]) + input_positions.append([]) + slot_mapping.append([]) + context_lens.append(1) + block_tables.append([]) + batch_size = graph_batch_size + + input_tokens = _make_tensor_with_pad(input_tokens, + max_len=1, + pad=0, + dtype=torch.long, + device=self.device) + input_positions = _make_tensor_with_pad(input_positions, + max_len=1, + pad=0, + dtype=torch.long, + device=self.device) + slot_mapping = _make_tensor_with_pad(slot_mapping, + max_len=1, + pad=_PAD_SLOT_ID, + dtype=torch.long, + device=self.device) + context_lens = torch.tensor(context_lens, + dtype=torch.int, + device=self.device) + + if use_captured_graph: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + input_block_tables = self.graph_block_tables[:batch_size] + for i, block_table in enumerate(block_tables): + if block_table: + input_block_tables[i, :len(block_table)] = block_table + block_tables = torch.tensor(input_block_tables, device=self.device) + else: + max_block_table_len = max( + len(block_table) for block_table in block_tables) + block_tables = _make_tensor_with_pad( + block_tables, + max_len=max_block_table_len, + pad=0, + dtype=torch.int, + device=self.device, + ) + + lora_index_mapping = [ + _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping + ] + + input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping, + prompt_lens=None, + max_seq_len=None, + start_loc=None, + max_context_len=max_context_len, + context_lens=context_lens, + block_tables=block_tables, + use_cuda_graph=use_captured_graph, + kv_cache_dtype=self.kv_cache_dtype, + ) + return (input_tokens, input_positions, input_metadata, + lora_index_mapping, lora_prompt_mapping, lora_requests) + """ + + def _prepare_sample( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + prompt_lens: List[int], + subquery_lens: Optional[List[int]], + ) -> SamplingMetadata: + seq_groups: List[Tuple[List[int], SamplingParams]] = [] + selected_token_indices: List[int] = [] + generators: List[torch.Generator] = [] + selected_token_start_idx = 0 + categorized_sample_indices = {t: [] for t in SamplingType} + categorized_sample_indices_start_idx = 0 + pin_memory = not self.in_wsl and not self.device_config.is_neuron + + max_subquery_len = max(subquery_lens) if subquery_lens else 1 + for i, seq_group_metadata in enumerate(seq_group_metadata_list): + seq_ids = list(seq_group_metadata.seq_data.keys()) + sampling_params = seq_group_metadata.sampling_params + seq_groups.append((seq_ids, sampling_params)) + + if seq_group_metadata.is_prompt: + assert len(seq_ids) == 1 + assert subquery_lens is not None + subquery_len = subquery_lens[i] + if sampling_params.prompt_logprobs is not None: + # NOTE: prompt token positions do not need sample, skip + categorized_sample_indices_start_idx += subquery_len - 1 + + categorized_sample_indices[ + sampling_params.sampling_type].append( + categorized_sample_indices_start_idx) + categorized_sample_indices_start_idx += 1 + + if sampling_params.prompt_logprobs is not None: + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + subquery_len - 1)) + selected_token_indices.append(selected_token_start_idx + + subquery_len - 1) + selected_token_start_idx += subquery_len + # TODO align + """ + selected_token_start_idx += max_subquery_len + """ + if sampling_params.seed 
is not None: + seq_group_metadata.state.generator = torch.Generator( + device="cuda").manual_seed(sampling_params.seed) + else: + num_seqs = len(seq_ids) + selected_token_indices.extend( + range(selected_token_start_idx, + selected_token_start_idx + num_seqs)) + selected_token_start_idx += num_seqs + + categorized_sample_indices[ + sampling_params.sampling_type].extend( + range(categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices_start_idx += num_seqs + + if sampling_params.seed is not None: + generators.append(seq_group_metadata.state.generator) + + selected_token_indices = _async_h2d(selected_token_indices, + dtype=torch.long, + target_device=self.device, + pin_memory=pin_memory) + categorized_sample_indices = { + t: _async_h2d(seq_ids, + dtype=torch.long, + target_device=self.device, + pin_memory=pin_memory) + for t, seq_ids in categorized_sample_indices.items() + } + + # TODO align + """ + categorized_sample_indices = { + t: _async_h2d(seq_ids, + dtype=torch.int, + target_device=self.device, + pin_memory=pin_memory) + for t, seq_ids in categorized_sample_indices.items() + } + """ + + seq_data: Dict[int, SequenceData] = {} + for seq_group_metadata in seq_group_metadata_list: + seq_data.update(seq_group_metadata.seq_data) + + sampling_metadata = SamplingMetadata( + seq_groups=seq_groups, + seq_data=seq_data, + prompt_lens=prompt_lens, + selected_token_indices=selected_token_indices, + categorized_sample_indices=categorized_sample_indices, + generators=generators, + ) + return sampling_metadata + + # TODO align + """ + def prepare_input_tensors( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, + Set[int], LoRAMapping]: + if self.is_driver_worker: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. + if is_prompt: + (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, input_metadata, + lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + subquery_lens = None + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens) + + if self.lora_config: + flat_lora_index_mapping = [ + item for sublist in lora_index_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + # Broadcast the metadata. 
+ metadata_dict = { + "input_tokens": input_tokens, + "input_positions": input_positions, + "is_prompt": input_metadata.is_prompt, + "slot_mapping": input_metadata.slot_mapping, + "prompt_lens": input_metadata.prompt_lens, + "max_seq_len": input_metadata.max_seq_len, + "start_loc": input_metadata.start_loc, + "max_context_len": input_metadata.max_context_len, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + "use_cuda_graph": input_metadata.use_cuda_graph, + "kv_cache_dtype": input_metadata.kv_cache_dtype, + "selected_token_indices": + sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, + } + broadcast_tensor_dict(metadata_dict, src=0) + else: + metadata_dict = broadcast_tensor_dict(src=0) + input_tokens = metadata_dict["input_tokens"] + input_positions = metadata_dict["input_positions"] + lora_mapping = metadata_dict["lora_mapping"] + lora_requests = metadata_dict["lora_requests"] + input_metadata = InputMetadata( + is_prompt=metadata_dict["is_prompt"], + slot_mapping=metadata_dict["slot_mapping"], + prompt_lens=metadata_dict["prompt_lens"], + max_seq_len=metadata_dict["max_seq_len"], + start_loc=metadata_dict["start_loc"], + max_context_len=metadata_dict["max_context_len"], + context_lens=metadata_dict["context_lens"], + block_tables=metadata_dict["block_tables"], + use_cuda_graph=metadata_dict["use_cuda_graph"], + kv_cache_dtype=metadata_dict["kv_cache_dtype"], + ) + sampling_metadata = SamplingMetadata( + seq_groups=None, + seq_data=None, + prompt_lens=None, + selected_token_indices=metadata_dict["selected_token_indices"], + categorized_sample_indices=None, + generators=None, + perform_sampling=False, + ) + + return (input_tokens, input_positions, input_metadata, + sampling_metadata, lora_requests, lora_mapping) + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> Optional[SamplerOutput]: + (input_tokens, input_positions, input_metadata, sampling_metadata, + lora_requests, + lora_mapping) = self.prepare_input_tensors(seq_group_metadata_list) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Execute the model. + if input_metadata.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + + # Sample the next token. + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + return output + """ + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + ) -> SamplerOutput: + # NOTE: We assume that all sequences in the group are all prompts or + # all decodes. + is_prompt = seq_group_metadata_list[0].is_prompt + # Prepare input tensors. 
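+        # In this fork the prompt path flattens all prompt tokens into 1-D
+        # tensors (see the active _prepare_prompt above), while the decode
+        # path emits one token per sequence, so the two branches hand
+        # differently shaped inputs to the same model call.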
+ if is_prompt: + (input_tokens, input_positions, input_metadata, prompt_lens, + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_prompt(seq_group_metadata_list) + else: + (input_tokens, input_positions, input_metadata, + lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_decode(seq_group_metadata_list) + prompt_lens = [] + subquery_lens = None + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens) + + if self.lora_config: + flat_lora_index_mapping = [ + item for sublist in lora_index_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + + # Execute the model. + if input_metadata.use_cuda_graph: + graph_batch_size = input_tokens.shape[0] + model_executable = self.graph_runners[graph_batch_size] + else: + model_executable = self.model + hidden_states = model_executable( + input_ids=input_tokens, + positions=input_positions, + kv_caches=kv_caches, + input_metadata=input_metadata, + ) + + sampling_metadata = self._prepare_sample(seq_group_metadata_list, + prompt_lens, + subquery_lens) + + # Sample the next token. + output = self.model.sample( + hidden_states=hidden_states, + sampling_metadata=sampling_metadata, + ) + return output + + @torch.inference_mode() + def profile_run(self) -> None: + # Enable top-k sampling to reflect the accurate memory usage. + vocab_size = self.model_config.get_vocab_size() + sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) + max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] + if self.lora_config: + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + + # Profile memory usage with max_num_sequences sequences and the total + # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + seq_data = SequenceData([0] * seq_len) + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, + ) + seqs.append(seq) + + # Run the model with the dummy inputs. 
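+        # The per-layer (None, None) tuples stand in for KV caches that have
+        # not been allocated yet; this dummy forward pass only measures peak
+        # activation memory at the maximum batched-token load (assumption: the
+        # attention path skips cache writes when the cache tensors are None).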
+ num_layers = self.model_config.get_num_layers(self.parallel_config) + kv_caches = [(None, None)] * num_layers + self.execute_model(seqs, kv_caches) + torch.cuda.synchronize() + return + + def remove_all_loras(self) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_all_loras() + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_loras() + + @torch.inference_mode() + def capture_model(self, kv_caches: List[KVCache]) -> None: + # NOTE(woosuk): This is a hack to ensure that the NCCL backend is never + # deleted before the CUDA graphs. + self.cupy_nccl_backend = cupy_utils.get_nccl_backend() + + assert not self.model_config.enforce_eager + logger.info("Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI.") + logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode. " + "You can also reduce the `max_num_seqs` as needed " + "to decrease memory usage.") + start_time = time.perf_counter() + + # Prepare dummy inputs. These will be reused for all batch sizes. + max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + input_tokens = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() + input_positions = torch.zeros(max_batch_size, 1, + dtype=torch.long).cuda() + slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() + slot_mapping.fill_(_PAD_SLOT_ID) + context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() + block_tables = torch.from_numpy(self.graph_block_tables).cuda() + + graph_batch_size = _get_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size + ] + + # NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce + # kernel, CuPy NCCL, and PyTorch NCCL. When using CUDA graph, we use + # either custom all-reduce kernel or CuPy NCCL. When not using CUDA + # graph, we use either custom all-reduce kernel or PyTorch NCCL. + # We always prioritize using custom all-reduce kernel but fall back + # to PyTorch or CuPy NCCL if it is disabled or not supported. + with custom_all_reduce.capture(): + # NOTE: Capturing the largest batch size first may help reduce the + # memory usage of CUDA graph. + for batch_size in reversed(batch_size_capture_list): + # Create dummy input_metadata. 
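+                # The metadata only holds views sliced to batch_size, so every
+                # captured graph shares the same preallocated max-batch buffers
+                # (input_tokens, slot_mapping, context_lens, block_tables).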
+ input_metadata = InputMetadata( + is_prompt=False, + slot_mapping=slot_mapping[:batch_size], + prompt_lens=None, + max_seq_len=None, + start_loc=None, + max_context_len=self.max_context_len_to_capture, + context_lens=context_lens[:batch_size], + block_tables=block_tables[:batch_size], + use_cuda_graph=True, + kv_cache_dtype=self.kv_cache_dtype, + ) + + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + + graph_runner = CUDAGraphRunner(self.model) + graph_runner.capture( + input_tokens[:batch_size], + input_positions[:batch_size], + kv_caches, + input_metadata, + memory_pool=self.graph_memory_pool, + ) + self.graph_memory_pool = graph_runner.graph.pool() + self.graph_runners[batch_size] = graph_runner + + end_time = time.perf_counter() + elapsed_time = end_time - start_time + # This usually takes < 10 seconds. + logger.info(f"Graph capturing finished in {elapsed_time:.0f} secs.") + + def __del__(self) -> None: + # Delete the CUDA graphs before deleting the CuPy NCCL communicator. + # NOTE(woosuk): This is necessary because otherwise deadlocks can + # happen. + # FIXME(woosuk): This is a bit hacky. Find a more robust solution. + self.graph_runners.clear() + self.cupy_nccl_backend = None + + +class CUDAGraphRunner: + + def __init__(self, model: nn.Module): + self.model = model + self.graph = None + self.input_buffers: Dict[str, torch.Tensor] = {} + self.output_buffers: Dict[str, torch.Tensor] = {} + + def capture( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[KVCache], + input_metadata: InputMetadata, + memory_pool, + ) -> None: + assert self.graph is None + # Run the model once without capturing the graph. + # This is to make sure that the captured graph does not include the + # kernel launches for initial benchmarking (e.g., Triton autotune). + with _maybe_cupy_nccl(): + self.model( + input_ids, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Capture the graph. + # NOTE(woosuk): Python 3.8 does not support multi-line with statements. + # https://stackoverflow.com/questions/31039022/python-multi-line-with-statement + self.graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(self.graph, pool=memory_pool): # noqa: SIM117 + with _maybe_cupy_nccl(): + hidden_states = self.model( + input_ids, + positions, + kv_caches, + input_metadata, + ) + torch.cuda.synchronize() + + # Save the input and output buffers. + self.input_buffers = { + "input_ids": input_ids, + "positions": positions, + "kv_caches": kv_caches, + "slot_mapping": input_metadata.slot_mapping, + "context_lens": input_metadata.context_lens, + "block_tables": input_metadata.block_tables, + } + self.output_buffers = {"hidden_states": hidden_states} + return + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], + input_metadata: InputMetadata, + ) -> torch.Tensor: + # KV caches are fixed tensors, so we don't need to copy them. + del kv_caches + + # Copy the input tensors to the input buffers. 
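+        # Graph replay re-runs the recorded kernels on the memory addresses
+        # captured above, so fresh inputs must be copied in place into those
+        # buffers; rebinding self.input_buffers to new tensors would have no
+        # effect on the replayed graph.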
+ self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True) + self.input_buffers["positions"].copy_(positions, non_blocking=True) + self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, + non_blocking=True) + self.input_buffers["context_lens"].copy_(input_metadata.context_lens, + non_blocking=True) + self.input_buffers["block_tables"].copy_(input_metadata.block_tables, + non_blocking=True) + + # Run the graph. + self.graph.replay() + + # Return the output tensor. + return self.output_buffers["hidden_states"] + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + +@contextlib.contextmanager +def _maybe_cupy_nccl(): + if cupy_utils.is_initialized() and not custom_all_reduce.is_initialized(): + with with_cupy_nccl_for_all_reduce(): + yield + else: + yield + + +def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: + assert len(x) <= max_len + return x + [pad] * (max_len - len(x)) + + +def _make_tensor_with_pad( + x: List[List[int]], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Optional[Union[str, torch.device]], +) -> torch.Tensor: + padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] + return torch.tensor(padded_x, dtype=dtype, device=device) + + +def _get_graph_batch_size(batch_size: int) -> int: + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return (batch_size + 7) // 8 * 8 + + +def _async_h2d( + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, +) -> torch.Tensor: + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py new file mode 100644 index 0000000..3229a21 --- /dev/null +++ b/vllm/worker/neuron_worker.py @@ -0,0 +1,191 @@ +"""A Neuron worker class.""" +from typing import Dict, List, Optional, Tuple + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.model_runner import ModelRunner + + +class Worker: + """A worker class that executes the model on a group of neuron cores. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." 
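+        # The Neuron worker drives the same ModelRunner as the GPU worker;
+        # ModelRunner.__init__ forces enforce_eager = True when
+        # device_config.is_neuron, so no CUDA graphs are captured on this path.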
+ + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + is_driver_worker=is_driver_worker) + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + def init_model(self) -> None: + # Initialize the distributed environment. + _init_distributed_environment(self.parallel_config, + self.rank, + self.distributed_init_method, + distributed_backend="gloo") + + # Initialize the model. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int = 128, + gpu_memory_utilization: float = 0.9, + cpu_swap_space: int = 0, + cache_dtype: str = "float16", + ) -> Tuple[int, int]: + """Simply returns max_num_seqs as num_gpu_blocks, 0 as num_cpu_blocks.""" + num_gpu_blocks = self.scheduler_config.max_num_seqs + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CacheEngine(self.cache_config, self.model_config, + self.parallel_config) + self.model_runner.set_block_size(self.cache_engine.block_size) + + def warm_up_model(self) -> None: + # Warm up is maintained in transformers-neuronx + pass + + def cache_swap( + self, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + if cache_events is not None: + raise NotImplementedError( + "cache operations are not implemented for neuron backend.") + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. 
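cache_swap takes three block-number mappings whose types follow the signature above: Dict[int, int] for swap-in and swap-out, Dict[int, List[int]] for copies. A hypothetical call is sketched below; which side of each mapping is the CPU block and which is the GPU block is an assumption here, and on this Neuron worker any issued operation raises NotImplementedError, so the call only makes sense on a backend whose cache engine implements these operations:

# Hypothetical block mappings for one scheduler step (directions assumed).
blocks_to_swap_in = {7: 12}      # CPU block 7 -> GPU block 12
blocks_to_swap_out = {3: 0}      # GPU block 3 -> CPU block 0
blocks_to_copy = {5: [8, 9]}     # source block 5 duplicated into blocks 8 and 9

# `worker` is an already-initialized worker whose cache engine supports these ops.
worker.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy)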
+ if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.gpu_cache) + return output + + +def _init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + distributed_init_method: Optional[str] = None, + distributed_backend: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + distributed_backend = distributed_backend if distributed_backend else "nccl" + torch.distributed.init_process_group( + backend=distributed_backend, + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + # A small all_reduce for warmup. + torch.distributed.all_reduce(torch.zeros(1)) + ensure_model_parallel_initialized(parallel_config.tensor_parallel_size, + parallel_config.pipeline_parallel_size) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py new file mode 100644 index 0000000..bdf3a25 --- /dev/null +++ b/vllm/worker/worker.py @@ -0,0 +1,354 @@ +"""A GPU worker class.""" +import gc +import os +from typing import Dict, List, Tuple, Set, Optional + +import torch +import torch.distributed + +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.model_executor import set_random_seed +from vllm.model_executor.parallel_utils import cupy_utils +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_tensor_dict) +from vllm.model_executor.parallel_utils.custom_all_reduce import init_custom_ar +from vllm.model_executor.parallel_utils.parallel_state import ( + ensure_model_parallel_initialized) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.worker.cache_engine import CacheEngine +from vllm.worker.model_runner import ModelRunner +from vllm.lora.request import LoRARequest +from vllm.utils import is_hip + + +class Worker: + """A worker class that executes (a partition of) the model on a GPU. + + Each worker is associated with a single GPU. The worker is responsible for + maintaining the KV cache and executing the model on the GPU. In case of + distributed inference, each worker is assigned a partition of the model. + """ + + def __init__( + self, + model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, + kv_cache_dtype: Optional[str] = "auto", + is_driver_worker: bool = False, + ) -> None: + self.model_config = model_config + self.parallel_config = parallel_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + self.lora_config = lora_config + self.is_driver_worker = is_driver_worker + if self.is_driver_worker: + assert self.rank == 0, "The driver worker must have rank 0." 
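execute_model uses a driver/follower control plane: the rank-0 driver packs the per-step metadata and broadcast_tensor_dict delivers it to every other worker, which blocks on the same call to receive it. The sketch below reproduces that flow with plain torch.distributed object broadcast rather than vLLM's broadcast_tensor_dict (whose implementation is not shown in this commit); it assumes the process group is already initialized, e.g. under torchrun:

import torch.distributed as dist

def broadcast_step_metadata(is_driver: bool, num_seq_groups: int = 0,
                            blocks_to_copy=None) -> dict:
    # Rank 0 packs the control data; every other rank passes a placeholder
    # and receives the driver's copy in place.
    if is_driver:
        payload = [{"num_seq_groups": num_seq_groups,
                    "blocks_to_copy": blocks_to_copy or {}}]
    else:
        payload = [None]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]

# Every rank then branches on the same shared metadata.
meta = broadcast_step_metadata(is_driver=(dist.get_rank() == 0),
                               num_seq_groups=3, blocks_to_copy={5: [8]})
if meta["num_seq_groups"] == 0:
    output = {}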
+ + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=False) + # TODO align + """ + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + device_config, + lora_config=self.lora_config, + kv_cache_dtype=kv_cache_dtype, + is_driver_worker=is_driver_worker) + """ + # Uninitialized cache engine. Will be initialized by + # self.init_cache_engine(). + self.cache_config = None + self.cache_engine = None + self.cache_events = None + self.gpu_cache = None + + def init_model(self, cupy_port: Optional[int] = None) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device(f"cuda:{self.local_rank}") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. + init_distributed_environment(self.parallel_config, self.rank, + cupy_port, self.distributed_init_method) + # Initialize the model. + set_random_seed(self.model_config.seed) + + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode() + def profile_num_available_blocks( + self, + block_size: int, + gpu_memory_utilization: float, + cpu_swap_space: int, + cache_dtype: str, + ) -> Tuple[int, int]: + """Profiles the peak memory usage of the model and returns the maximum + number of GPU and CPU cache blocks that can be allocated. + + Args: + block_size: The size of the cache block. + gpu_memory_utilization: The fraction of the total GPU memory to use. + cpu_swap_space: The size of the CPU swap space in bytes. + """ + # Profile the memory usage of the model and get the maximum number of + # cache blocks that can be allocated with the remaining free memory. + torch.cuda.empty_cache() + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + self.model_runner.profile_run() + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. + torch.cuda.synchronize() + free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + # NOTE(woosuk): Here we assume that the other processes using the same + # GPU did not change their memory usage during the profiling. 
+ peak_memory = self.init_gpu_memory - free_gpu_memory + + cache_block_size = CacheEngine.get_cache_block_size( + block_size, cache_dtype, self.model_config, self.parallel_config) + num_gpu_blocks = int( + (total_gpu_memory * gpu_memory_utilization - peak_memory) // + cache_block_size) + num_cpu_blocks = int(cpu_swap_space // cache_block_size) + num_gpu_blocks = max(num_gpu_blocks, 0) + num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() + torch.cuda.empty_cache() + return num_gpu_blocks, num_cpu_blocks + + def init_cache_engine(self, cache_config: CacheConfig) -> None: + self.cache_config = cache_config + self.cache_engine = CacheEngine(self.cache_config, self.model_config, + self.parallel_config) + self.cache_events = self.cache_engine.events + self.gpu_cache = self.cache_engine.gpu_cache + self.model_runner.set_block_size(self.cache_engine.block_size) + + def warm_up_model(self) -> None: + if not self.model_config.enforce_eager: + self.model_runner.capture_model(self.gpu_cache) + # Reset the seed to ensure that the random state is not affected by + # the model initialization and profiling. + set_random_seed(self.model_config.seed) + + def cache_swap( + self, + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> None: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + # TODO(woosuk): Profile swapping overhead and optimize if needed. + if cache_events is not None: + for event in cache_events: + event.wait() + + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + # Issue cache operations. + issued_cache_op = False + if blocks_to_swap_in: + self.cache_engine.swap_in(blocks_to_swap_in) + issued_cache_op = True + if blocks_to_swap_out: + self.cache_engine.swap_out(blocks_to_swap_out) + issued_cache_op = True + if blocks_to_copy: + self.cache_engine.copy(blocks_to_copy) + issued_cache_op = True + + cache_events = self.cache_events if issued_cache_op else None + + # Wait for cache operations to finish. + # TODO(woosuk): Profile swapping overhead and optimize if needed. + if cache_events is not None: + for event in cache_events: + event.wait() + # If there is no input, we don't need to execute the model. 
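The block accounting in profile_num_available_blocks is: usable bytes = total memory * gpu_memory_utilization minus the profiled peak, then integer-divide by the per-block KV-cache footprint from CacheEngine.get_cache_block_size. A worked example with made-up numbers; the footprint formula below is the usual KV-cache estimate and is an assumption, not code read from CacheEngine:

# Assumed model shape (illustrative, not read from a real config).
num_layers, num_kv_heads, head_size = 32, 32, 128
block_size, dtype_bytes = 16, 2                      # fp16 KV entries

# Typical per-block footprint: K and V for every token slot in every layer.
cache_block_size = (block_size * num_kv_heads * head_size
                    * num_layers * 2 * dtype_bytes)  # 8 MiB here

total_gpu_memory = 80 * 1024**3                      # 80 GiB card (assumed)
gpu_memory_utilization = 0.90
peak_memory = 20 * 1024**3                           # profiled peak (assumed)
cpu_swap_space = 4 * 1024**3                         # 4 GiB of CPU swap

num_gpu_blocks = int((total_gpu_memory * gpu_memory_utilization
                      - peak_memory) // cache_block_size)
num_cpu_blocks = int(cpu_swap_space // cache_block_size)
print(num_gpu_blocks, num_cpu_blocks)                # -> 6656 512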
+ if not seq_group_metadata_list: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.gpu_cache) + return output + + # TODO align + """ + @torch.inference_mode() + def execute_model( + self, + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None, + blocks_to_swap_in: Optional[Dict[int, int]] = None, + blocks_to_swap_out: Optional[Dict[int, int]] = None, + blocks_to_copy: Optional[Dict[int, List[int]]] = None, + ) -> Optional[SamplerOutput]: + if self.is_driver_worker: + assert seq_group_metadata_list is not None + num_seq_groups = len(seq_group_metadata_list) + assert blocks_to_swap_in is not None + assert blocks_to_swap_out is not None + assert blocks_to_copy is not None + data = { + "num_seq_groups": num_seq_groups, + "blocks_to_swap_in": blocks_to_swap_in, + "blocks_to_swap_out": blocks_to_swap_out, + "blocks_to_copy": blocks_to_copy, + } + broadcast_tensor_dict(data, src=0) + else: + data = broadcast_tensor_dict(src=0) + num_seq_groups = data["num_seq_groups"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + blocks_to_copy = data["blocks_to_copy"] + + self.cache_swap(blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy) + + # If there is no input, we don't need to execute the model. + if num_seq_groups == 0: + return {} + + output = self.model_runner.execute_model(seq_group_metadata_list, + self.gpu_cache) + return output + """ + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + + +def init_distributed_environment( + parallel_config: ParallelConfig, + rank: int, + cupy_port: Optional[int], + distributed_init_method: Optional[str] = None, +) -> None: + """Initialize the distributed environment.""" + if torch.distributed.is_initialized(): + torch_world_size = torch.distributed.get_world_size() + if torch_world_size != parallel_config.world_size: + raise RuntimeError( + "torch.distributed is already initialized but the torch world " + "size does not match parallel_config.world_size " + f"({torch_world_size} vs. {parallel_config.world_size}).") + elif not distributed_init_method: + raise ValueError( + "distributed_init_method must be set if torch.distributed " + "is not already initialized") + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) + + if cupy_utils.is_initialized(): + cupy_world_size = cupy_utils.get_world_size() + if cupy_world_size != parallel_config.world_size: + raise RuntimeError( + "cupy.distributed is already initialized but the cupy world " + "size does not match parallel_config.world_size " + f"({cupy_world_size} vs. {parallel_config.world_size}).") + elif (parallel_config.world_size > 1 and cupy_port is not None + and not is_hip()): + # NOTE(woosuk): We don't initialize CuPy process group when world size + # is 1. + # TODO(woosuk): Support multi-node connection. + cupy_utils.init_process_group( + world_size=parallel_config.world_size, + rank=rank, + host="localhost", + port=cupy_port, + ) + + # A small all_reduce for warmup. 
+    torch.distributed.all_reduce(torch.zeros(1).cuda())
+    if cupy_utils.is_initialized():
+        cupy_utils.all_reduce(torch.zeros(1).cuda())
+    ensure_model_parallel_initialized(parallel_config.tensor_parallel_size,
+                                      parallel_config.pipeline_parallel_size)
+
+    # Initialize a custom fast all-reduce implementation.
+    if not parallel_config.disable_custom_all_reduce:
+        init_custom_ar()
+
+
+def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
+    # Check if the GPU supports the dtype.
+    if torch_dtype == torch.bfloat16:
+        return  # avoid capability error
+        compute_capability = torch.cuda.get_device_capability()
+        if compute_capability[0] < 8:
+            gpu_name = torch.cuda.get_device_name()
+            raise ValueError(
+                "Bfloat16 is only supported on GPUs with compute capability "
+                f"of at least 8.0. Your {gpu_name} GPU has compute capability "
+                f"{compute_capability[0]}.{compute_capability[1]}. "
+                "You can use float16 instead by explicitly setting the "
+                "`dtype` flag in CLI, for example: --dtype=half.")
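init_distributed_environment tolerates a process group that is already initialized (e.g. when the worker is launched under torchrun) as long as the world size matches, then issues a tiny all_reduce so the NCCL/Gloo communicators are set up before real traffic. A single-process sanity check of that sequence, using the gloo backend so it runs without GPUs; the rendezvous address and port are placeholders:

import os
import torch
import torch.distributed as dist

# Placeholder rendezvous settings for a local, single-process run.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")

if not dist.is_initialized():
    dist.init_process_group(backend="gloo", world_size=1, rank=0,
                            init_method="env://")
elif dist.get_world_size() != 1:
    raise RuntimeError("torch.distributed is already initialized with a "
                       "different world size")

# Warmup collective, mirroring the zeros(1) all_reduce issued above.
dist.all_reduce(torch.zeros(1))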
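_check_if_gpu_supports_dtype is short-circuited in this port (the early return skips the capability gate), but the underlying rule is the standard one: bfloat16 needs compute capability 8.0 or newer. A minimal standalone version of that gate, shown as a sketch rather than the function used here:

import torch

def supports_bfloat16() -> bool:
    # bfloat16 needs compute capability >= 8.0 (Ampere or newer NVIDIA GPUs).
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8

if torch.cuda.is_available() and not supports_bfloat16():
    print(f"{torch.cuda.get_device_name()} cannot run bfloat16; "
          "use --dtype=half instead")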